Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b637c77d
Commit
b637c77d
authored
Dec 19, 2022
by
Anthony Chang
Browse files
format
parent
8dad40d0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
45 deletions
+45
-45
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+42
-42
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+3
-3
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
b637c77d
...
@@ -340,28 +340,28 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -340,28 +340,28 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
(),
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
(),
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
());
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
());
return
SpaceFillingCurve
<
return
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
{};
// SnakeCurved
false
>
{};
// SnakeCurved
}
}
__host__
__device__
static
constexpr
auto
MakeCThreadIndexAdaptor8DTo2D
()
__host__
__device__
static
constexpr
auto
MakeCThreadIndexAdaptor8DTo2D
()
{
{
if
constexpr
(
TransposeC
)
if
constexpr
(
TransposeC
)
{
{
constexpr
auto
c_thread_desc
=
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
c_thread_desc
=
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
m0
=
c_thread_desc
.
GetLength
(
Number
<
0
>
{});
constexpr
auto
m0
=
c_thread_desc
.
GetLength
(
Number
<
0
>
{});
constexpr
auto
n0
=
c_thread_desc
.
GetLength
(
Number
<
1
>
{});
constexpr
auto
n0
=
c_thread_desc
.
GetLength
(
Number
<
1
>
{});
constexpr
auto
m1
=
c_thread_desc
.
GetLength
(
Number
<
2
>
{});
constexpr
auto
m1
=
c_thread_desc
.
GetLength
(
Number
<
2
>
{});
constexpr
auto
n1
=
c_thread_desc
.
GetLength
(
Number
<
3
>
{});
constexpr
auto
n1
=
c_thread_desc
.
GetLength
(
Number
<
3
>
{});
constexpr
auto
m2
=
c_thread_desc
.
GetLength
(
Number
<
4
>
{});
constexpr
auto
m2
=
c_thread_desc
.
GetLength
(
Number
<
4
>
{});
constexpr
auto
n2
=
c_thread_desc
.
GetLength
(
Number
<
5
>
{});
constexpr
auto
n2
=
c_thread_desc
.
GetLength
(
Number
<
5
>
{});
constexpr
auto
n3
=
c_thread_desc
.
GetLength
(
Number
<
6
>
{});
constexpr
auto
n3
=
c_thread_desc
.
GetLength
(
Number
<
6
>
{});
constexpr
auto
n4
=
c_thread_desc
.
GetLength
(
Number
<
7
>
{});
constexpr
auto
n4
=
c_thread_desc
.
GetLength
(
Number
<
7
>
{});
constexpr
auto
thread_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
constexpr
auto
thread_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
m0
,
m1
,
m2
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
m0
,
m1
,
m2
)),
make_unmerge_transform
(
make_tuple
(
n0
,
n1
,
n2
,
n3
,
n4
))),
make_unmerge_transform
(
make_tuple
(
n0
,
n1
,
n2
,
n3
,
n4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
return
thread_idx_to_m_n_adaptor
;
return
thread_idx_to_m_n_adaptor
;
...
@@ -369,17 +369,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -369,17 +369,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
else
else
{
{
constexpr
auto
c_thread_desc
=
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
c_thread_desc
=
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
m0
=
c_thread_desc
.
GetLength
(
Number
<
0
>
{});
constexpr
auto
m0
=
c_thread_desc
.
GetLength
(
Number
<
0
>
{});
constexpr
auto
n0
=
c_thread_desc
.
GetLength
(
Number
<
1
>
{});
constexpr
auto
n0
=
c_thread_desc
.
GetLength
(
Number
<
1
>
{});
constexpr
auto
m1
=
c_thread_desc
.
GetLength
(
Number
<
2
>
{});
constexpr
auto
m1
=
c_thread_desc
.
GetLength
(
Number
<
2
>
{});
constexpr
auto
n1
=
c_thread_desc
.
GetLength
(
Number
<
3
>
{});
constexpr
auto
n1
=
c_thread_desc
.
GetLength
(
Number
<
3
>
{});
constexpr
auto
m2
=
c_thread_desc
.
GetLength
(
Number
<
4
>
{});
constexpr
auto
m2
=
c_thread_desc
.
GetLength
(
Number
<
4
>
{});
constexpr
auto
m3
=
c_thread_desc
.
GetLength
(
Number
<
5
>
{});
constexpr
auto
m3
=
c_thread_desc
.
GetLength
(
Number
<
5
>
{});
constexpr
auto
m4
=
c_thread_desc
.
GetLength
(
Number
<
6
>
{});
constexpr
auto
m4
=
c_thread_desc
.
GetLength
(
Number
<
6
>
{});
constexpr
auto
n2
=
c_thread_desc
.
GetLength
(
Number
<
7
>
{});
constexpr
auto
n2
=
c_thread_desc
.
GetLength
(
Number
<
7
>
{});
constexpr
auto
thread_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
constexpr
auto
thread_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
m0
,
m1
,
m2
,
m3
,
m4
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
m0
,
m1
,
m2
,
m3
,
m4
)),
make_unmerge_transform
(
make_tuple
(
n0
,
n1
,
n2
))),
make_unmerge_transform
(
make_tuple
(
n0
,
n1
,
n2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
return
thread_idx_to_m_n_adaptor
;
return
thread_idx_to_m_n_adaptor
;
...
@@ -1002,20 +1002,20 @@ struct BlockwiseGemmXdlops_v2
...
@@ -1002,20 +1002,20 @@ struct BlockwiseGemmXdlops_v2
__host__
__device__
static
constexpr
auto
MakeCThreadIndexAdaptor8DTo2D
()
__host__
__device__
static
constexpr
auto
MakeCThreadIndexAdaptor8DTo2D
()
{
{
if
constexpr
(
TransposeC
)
if
constexpr
(
TransposeC
)
{
{
constexpr
auto
c_thread_desc
=
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
c_thread_desc
=
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
constexpr
auto
m0
=
c_thread_desc
.
GetLength
(
Number
<
0
>
{});
constexpr
auto
m0
=
c_thread_desc
.
GetLength
(
Number
<
0
>
{});
constexpr
auto
n0
=
c_thread_desc
.
GetLength
(
Number
<
1
>
{});
constexpr
auto
n0
=
c_thread_desc
.
GetLength
(
Number
<
1
>
{});
constexpr
auto
m1
=
c_thread_desc
.
GetLength
(
Number
<
2
>
{});
constexpr
auto
m1
=
c_thread_desc
.
GetLength
(
Number
<
2
>
{});
constexpr
auto
n1
=
c_thread_desc
.
GetLength
(
Number
<
3
>
{});
constexpr
auto
n1
=
c_thread_desc
.
GetLength
(
Number
<
3
>
{});
constexpr
auto
m2
=
c_thread_desc
.
GetLength
(
Number
<
4
>
{});
constexpr
auto
m2
=
c_thread_desc
.
GetLength
(
Number
<
4
>
{});
constexpr
auto
n2
=
c_thread_desc
.
GetLength
(
Number
<
5
>
{});
constexpr
auto
n2
=
c_thread_desc
.
GetLength
(
Number
<
5
>
{});
constexpr
auto
n3
=
c_thread_desc
.
GetLength
(
Number
<
6
>
{});
constexpr
auto
n3
=
c_thread_desc
.
GetLength
(
Number
<
6
>
{});
constexpr
auto
n4
=
c_thread_desc
.
GetLength
(
Number
<
7
>
{});
constexpr
auto
n4
=
c_thread_desc
.
GetLength
(
Number
<
7
>
{});
constexpr
auto
thread_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
constexpr
auto
thread_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
m0
,
m1
,
m2
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
m0
,
m1
,
m2
)),
make_unmerge_transform
(
make_tuple
(
n0
,
n1
,
n2
,
n3
,
n4
))),
make_unmerge_transform
(
make_tuple
(
n0
,
n1
,
n2
,
n3
,
n4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
return
thread_idx_to_m_n_adaptor
;
return
thread_idx_to_m_n_adaptor
;
...
@@ -1023,17 +1023,17 @@ struct BlockwiseGemmXdlops_v2
...
@@ -1023,17 +1023,17 @@ struct BlockwiseGemmXdlops_v2
else
else
{
{
constexpr
auto
c_thread_desc
=
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
c_thread_desc
=
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
m0
=
c_thread_desc
.
GetLength
(
Number
<
0
>
{});
constexpr
auto
m0
=
c_thread_desc
.
GetLength
(
Number
<
0
>
{});
constexpr
auto
n0
=
c_thread_desc
.
GetLength
(
Number
<
1
>
{});
constexpr
auto
n0
=
c_thread_desc
.
GetLength
(
Number
<
1
>
{});
constexpr
auto
m1
=
c_thread_desc
.
GetLength
(
Number
<
2
>
{});
constexpr
auto
m1
=
c_thread_desc
.
GetLength
(
Number
<
2
>
{});
constexpr
auto
n1
=
c_thread_desc
.
GetLength
(
Number
<
3
>
{});
constexpr
auto
n1
=
c_thread_desc
.
GetLength
(
Number
<
3
>
{});
constexpr
auto
m2
=
c_thread_desc
.
GetLength
(
Number
<
4
>
{});
constexpr
auto
m2
=
c_thread_desc
.
GetLength
(
Number
<
4
>
{});
constexpr
auto
m3
=
c_thread_desc
.
GetLength
(
Number
<
5
>
{});
constexpr
auto
m3
=
c_thread_desc
.
GetLength
(
Number
<
5
>
{});
constexpr
auto
m4
=
c_thread_desc
.
GetLength
(
Number
<
6
>
{});
constexpr
auto
m4
=
c_thread_desc
.
GetLength
(
Number
<
6
>
{});
constexpr
auto
n2
=
c_thread_desc
.
GetLength
(
Number
<
7
>
{});
constexpr
auto
n2
=
c_thread_desc
.
GetLength
(
Number
<
7
>
{});
constexpr
auto
thread_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
constexpr
auto
thread_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
m0
,
m1
,
m2
,
m3
,
m4
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
m0
,
m1
,
m2
,
m3
,
m4
)),
make_unmerge_transform
(
make_tuple
(
n0
,
n1
,
n2
))),
make_unmerge_transform
(
make_tuple
(
n0
,
n1
,
n2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
return
thread_idx_to_m_n_adaptor
;
return
thread_idx_to_m_n_adaptor
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
b637c77d
...
@@ -487,7 +487,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -487,7 +487,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
M1
=
ygrad_grid_desc_m0_o_m1
.
GetLength
(
I2
);
const
auto
M1
=
ygrad_grid_desc_m0_o_m1
.
GetLength
(
I2
);
constexpr
auto
Y_O1
=
AK1
;
constexpr
auto
Y_O1
=
AK1
;
const
auto
Y_O0
=
O
/
Y_O1
;
const
auto
Y_O0
=
O
/
Y_O1
;
const
auto
ygrad_grid_desc_o0_m_o1
=
transform_tensor_descriptor
(
const
auto
ygrad_grid_desc_o0_m_o1
=
transform_tensor_descriptor
(
ygrad_grid_desc_m0_o_m1
,
ygrad_grid_desc_m0_o_m1
,
...
@@ -508,7 +508,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -508,7 +508,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
N1
=
v_grid_desc_n0_o_n1
.
GetLength
(
I2
);
const
auto
N1
=
v_grid_desc_n0_o_n1
.
GetLength
(
I2
);
constexpr
auto
V_O1
=
BK1
;
constexpr
auto
V_O1
=
BK1
;
const
auto
V_O0
=
O
/
V_O1
;
const
auto
V_O0
=
O
/
V_O1
;
const
auto
v_grid_desc_o0_n_o1
=
transform_tensor_descriptor
(
const
auto
v_grid_desc_o0_n_o1
=
transform_tensor_descriptor
(
v_grid_desc_n0_o_n1
,
v_grid_desc_n0_o_n1
,
...
@@ -1414,7 +1414,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1414,7 +1414,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
pgrad_blockwise_gemm
=
typename
PGradGemmTile_M_N_O
::
BlockwiseGemm
{};
auto
pgrad_blockwise_gemm
=
typename
PGradGemmTile_M_N_O
::
BlockwiseGemm
{};
auto
pgrad_thread_buf
=
pgrad_blockwise_gemm
.
GetCThreadBuffer
();
auto
pgrad_thread_buf
=
pgrad_blockwise_gemm
.
GetCThreadBuffer
();
const
auto
pgrad_gemm_tile_ygrad_block_reset_copy_step
=
const
auto
pgrad_gemm_tile_ygrad_block_reset_copy_step
=
make_multi_index
(
-
ygrad_grid_desc_o0_m_o1
.
GetLength
(
I0
),
0
,
0
);
make_multi_index
(
-
ygrad_grid_desc_o0_m_o1
.
GetLength
(
I0
),
0
,
0
);
const
auto
pgrad_gemm_tile_v_block_reset_copy_step
=
const
auto
pgrad_gemm_tile_v_block_reset_copy_step
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment