Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
35e5c532
Commit
35e5c532
authored
Jun 19, 2023
by
aska-0096
Browse files
clang format
parent
b010b095
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
28 deletions
+50
-28
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+50
-28
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
35e5c532
...
@@ -252,7 +252,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -252,7 +252,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1Spec
,
B1Spec
,
CSpec
>
;
CSpec
>
;
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths_vec
,
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides_vec
)
{
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
...
@@ -260,7 +261,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -260,7 +261,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number
<
AK1
>
{});
Number
<
AK1
>
{});
}
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_lengths_vec
,
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_strides_vec
)
{
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
...
@@ -268,8 +270,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -268,8 +270,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number
<
BK1
>
{});
Number
<
BK1
>
{});
}
}
static
auto
static
auto
MakeB1GridDescriptor_BK0_N_BK1
(
MakeB1GridDescriptor_BK0_N_BK1
(
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_lengths_vec
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_strides_vec
)
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_strides_vec
)
{
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
...
@@ -457,10 +459,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -457,10 +459,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
a_gs_ms_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_lengths
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b_gs_ns_ks_strides
,
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
array
<
index_t
,
NumDimG
+
NumDimM
+
NumDimN
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>&
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD1Tensor
>&
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD1Tensor
>&
...
@@ -846,21 +852,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -846,21 +852,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_lengths_
{};
// c_gs_ms_os_lengths
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_lengths_
{};
// c_gs_ms_os_lengths
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_strides_
{};
// c_gs_ms_os_strides
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_strides_
{};
// c_gs_ms_os_strides
std
::
copy
(
a_gs_ms_ks_lengths
.
begin
(),
a_gs_ms_ks_lengths
.
begin
()
+
dimension
,
a_gs_ms_ks_lengths_
.
begin
());
std
::
copy
(
a_gs_ms_ks_lengths
.
begin
(),
std
::
copy
(
a_gs_ms_ks_strides
.
begin
(),
a_gs_ms_ks_strides
.
begin
()
+
dimension
,
a_gs_ms_ks_strides_
.
begin
());
a_gs_ms_ks_lengths
.
begin
()
+
dimension
,
std
::
copy
(
b_gs_ns_ks_lengths
.
begin
(),
b_gs_ns_ks_lengths
.
begin
()
+
dimension
,
b_gs_ns_ks_lengths_
.
begin
());
a_gs_ms_ks_lengths_
.
begin
());
std
::
copy
(
b_gs_ns_ks_strides
.
begin
(),
b_gs_ns_ks_strides
.
begin
()
+
dimension
,
b_gs_ns_ks_strides_
.
begin
());
std
::
copy
(
a_gs_ms_ks_strides
.
begin
(),
a_gs_ms_ks_strides
.
begin
()
+
dimension
,
a_gs_ms_ks_strides_
.
begin
());
std
::
copy
(
b_gs_ns_ks_lengths
.
begin
(),
b_gs_ns_ks_lengths
.
begin
()
+
dimension
,
b_gs_ns_ks_lengths_
.
begin
());
std
::
copy
(
b_gs_ns_ks_strides
.
begin
(),
b_gs_ns_ks_strides
.
begin
()
+
dimension
,
b_gs_ns_ks_strides_
.
begin
());
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
(),
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
(),
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_lengths_
.
begin
());
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_lengths_
.
begin
());
// b1_gs_os_ns_lengths
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_strides
.
begin
(),
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_strides
.
begin
(),
b1_gs_gemm1ns_gemm1ks_strides
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_strides
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_strides_
.
begin
());
// b1_gs_os_ns_strides
b1_gs_gemm1ns_gemm1ks_strides_
.
begin
());
// b1_gs_os_ns_strides
std
::
copy
(
c_gs_ms_gemm1ns_lengths
.
begin
(),
std
::
copy
(
c_gs_ms_gemm1ns_lengths
.
begin
(),
c_gs_ms_gemm1ns_lengths
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_lengths
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_lengths_
.
begin
());
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths_
.
begin
());
// c_gs_ms_os_lengths
std
::
copy
(
c_gs_ms_gemm1ns_strides
.
begin
(),
std
::
copy
(
c_gs_ms_gemm1ns_strides
.
begin
(),
c_gs_ms_gemm1ns_strides
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_strides
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_strides_
.
begin
());
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides_
.
begin
());
// c_gs_ms_os_strides
return
Argument
{
p_a
,
return
Argument
{
p_a
,
...
@@ -930,21 +944,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -930,21 +944,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_lengths_
{};
// c_gs_ms_os_lengths
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_lengths_
{};
// c_gs_ms_os_lengths
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_strides_
{};
// c_gs_ms_os_strides
std
::
array
<
index_t
,
dimension
>
c_gs_ms_gemm1ns_strides_
{};
// c_gs_ms_os_strides
std
::
copy
(
a_gs_ms_ks_lengths
.
begin
(),
a_gs_ms_ks_lengths
.
begin
()
+
dimension
,
a_gs_ms_ks_lengths_
.
begin
());
std
::
copy
(
a_gs_ms_ks_lengths
.
begin
(),
std
::
copy
(
a_gs_ms_ks_strides
.
begin
(),
a_gs_ms_ks_strides
.
begin
()
+
dimension
,
a_gs_ms_ks_strides_
.
begin
());
a_gs_ms_ks_lengths
.
begin
()
+
dimension
,
std
::
copy
(
b_gs_ns_ks_lengths
.
begin
(),
b_gs_ns_ks_lengths
.
begin
()
+
dimension
,
b_gs_ns_ks_lengths_
.
begin
());
a_gs_ms_ks_lengths_
.
begin
());
std
::
copy
(
b_gs_ns_ks_strides
.
begin
(),
b_gs_ns_ks_strides
.
begin
()
+
dimension
,
b_gs_ns_ks_strides_
.
begin
());
std
::
copy
(
a_gs_ms_ks_strides
.
begin
(),
a_gs_ms_ks_strides
.
begin
()
+
dimension
,
a_gs_ms_ks_strides_
.
begin
());
std
::
copy
(
b_gs_ns_ks_lengths
.
begin
(),
b_gs_ns_ks_lengths
.
begin
()
+
dimension
,
b_gs_ns_ks_lengths_
.
begin
());
std
::
copy
(
b_gs_ns_ks_strides
.
begin
(),
b_gs_ns_ks_strides
.
begin
()
+
dimension
,
b_gs_ns_ks_strides_
.
begin
());
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
(),
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
(),
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_lengths
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_lengths_
.
begin
());
// b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_lengths_
.
begin
());
// b1_gs_os_ns_lengths
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_strides
.
begin
(),
std
::
copy
(
b1_gs_gemm1ns_gemm1ks_strides
.
begin
(),
b1_gs_gemm1ns_gemm1ks_strides
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_strides
.
begin
()
+
dimension
,
b1_gs_gemm1ns_gemm1ks_strides_
.
begin
());
// b1_gs_os_ns_strides
b1_gs_gemm1ns_gemm1ks_strides_
.
begin
());
// b1_gs_os_ns_strides
std
::
copy
(
c_gs_ms_gemm1ns_lengths
.
begin
(),
std
::
copy
(
c_gs_ms_gemm1ns_lengths
.
begin
(),
c_gs_ms_gemm1ns_lengths
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_lengths
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_lengths_
.
begin
());
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths_
.
begin
());
// c_gs_ms_os_lengths
std
::
copy
(
c_gs_ms_gemm1ns_strides
.
begin
(),
std
::
copy
(
c_gs_ms_gemm1ns_strides
.
begin
(),
c_gs_ms_gemm1ns_strides
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_strides
.
begin
()
+
dimension
,
c_gs_ms_gemm1ns_strides_
.
begin
());
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides_
.
begin
());
// c_gs_ms_os_strides
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment