Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f1b2e521
Commit
f1b2e521
authored
Dec 07, 2022
by
Anthony Chang
Browse files
format
parent
4ae9919e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
67 deletions
+67
-67
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+15
-8
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+51
-58
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
f1b2e521
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
f1b2e521
...
@@ -361,10 +361,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -361,10 +361,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
v_gs_ns_os_lengths_vec
,
v_gs_ns_os_strides_vec
)
v_gs_ns_os_lengths_vec
,
v_gs_ns_os_strides_vec
)
.
second
;
.
second
;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec,
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ", v_gs_os_ns_strides_vec, ",") << std::endl;
// ",") << std::endl; LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ",
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
// v_gs_os_ns_strides_vec, ",") << std::endl; LogRangeAsType<float>(std::cout <<
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec, ",") << std::endl;
// "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec,
// ",") << std::endl;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
make_tuple
(
NPerBlock
,
Gemm1NPerBlock
),
make_tuple
(
NPerBlock
,
Gemm1NPerBlock
),
Sequence
<
padder
.
PadN
,
padder
.
PadO
>
{});
Sequence
<
padder
.
PadN
,
padder
.
PadO
>
{});
...
@@ -685,7 +687,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -685,7 +687,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_base_ptr_of_batch_
{
compute_base_ptr_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())}
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())}
{
{
// TODO ANT: implement bias addition
// TODO ANT: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc0_biases
;
...
@@ -726,9 +732,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -726,9 +732,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<<
c_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
<<
c_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
// c_grid_desc_g_m_n_.Print();
// c_grid_desc_g_m_n_.Print();
std
::
cout
<<
"vgrad_grid_desc_n_o_: "
<<
vgrad_grid_desc_n_o_
.
GetLength
(
I0
)
<<
", "
<<
vgrad_grid_desc_n_o_
.
GetLength
(
I1
)
<<
'\n'
;
std
::
cout
<<
"vgrad_grid_desc_n_o_: "
<<
vgrad_grid_desc_n_o_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"ygrad_grid_desc_m0_o_m1_: "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I0
)
<<
", "
<<
vgrad_grid_desc_n_o_
.
GetLength
(
I1
)
<<
'\n'
;
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I1
)
<<
", "
std
::
cout
<<
"ygrad_grid_desc_m0_o_m1_: "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I0
)
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I1
)
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I2
)
<<
'\n'
;
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I2
)
<<
'\n'
;
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
f1b2e521
...
@@ -421,7 +421,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -421,7 +421,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
=
transform_tensor_descriptor
(
const
auto
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
=
transform_tensor_descriptor
(
lse_grid_desc_m
,
lse_grid_desc_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MXdlPerWave
>
{},
MWave
,
Number
<
MPerXdl
>
{}))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MXdlPerWave
>
{},
MWave
,
Number
<
MPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}));
...
@@ -1015,8 +1016,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1015,8 +1016,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
1
,
// DstScalarPerVector
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
{
true
>
{
p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_multi_index
(
make_multi_index
(
p_thread_origin_nd_idx_on_block
[
I0
],
p_thread_origin_nd_idx_on_block
[
I0
],
p_thread_origin_nd_idx_on_block
[
I1
],
p_thread_origin_nd_idx_on_block
[
I1
],
...
@@ -1087,8 +1087,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1087,8 +1087,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
ygrad_block_space_offset
,
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
ygrad_block_space_offset
,
ygrad_block_desc_m0_o_m1
.
GetElementSpaceSize
());
ygrad_block_desc_m0_o_m1
.
GetElementSpaceSize
());
auto
vgrad_blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
auto
vgrad_blockwise_gemm
=
BlockSize
,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
DataType
,
DataType
,
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
p_block_desc_m0_n_m1
),
decltype
(
p_block_desc_m0_n_m1
),
...
@@ -1107,12 +1107,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1107,12 +1107,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2
=
transform_tensor_descriptor
(
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2
=
transform_tensor_descriptor
(
vgrad_grid_desc_n_o
,
vgrad_grid_desc_n_o
,
make_tuple
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
make_unmerge_transform
(
make_tuple
(
I1
,
VGradGemmTile_N_O_M
::
GemmNWave
,
MPerXdl
)),
VGradGemmTile_N_O_M
::
GemmNWave
,
make_unmerge_transform
(
make_tuple
(
I1
,
VGradGemmTile_N_O_M
::
GemmOWave
,
NPerXdl
))),
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
I1
,
VGradGemmTile_N_O_M
::
GemmOWave
,
NPerXdl
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
...
@@ -1142,10 +1138,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1142,10 +1138,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
index_t
o_thread_data_idx_on_grid
=
const
index_t
o_thread_data_idx_on_grid
=
vgrad_thread_mtx_on_block_n_o
[
I1
]
+
gemm1_n_block_data_idx_on_grid
;
vgrad_thread_mtx_on_block_n_o
[
I1
]
+
gemm1_n_block_data_idx_on_grid
;
const
auto
n_thread_data_on_grid_to_n0_n1_n2_adaptor
=
const
auto
n_thread_data_on_grid_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
VGrad_N0
,
VGrad_N1
,
VGrad_N2
))),
make_tuple
(
make_merge_transform
(
make_tuple
(
VGrad_N0
,
VGrad_N1
,
VGrad_N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
0
>
{}));
...
@@ -1366,10 +1360,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1366,10 +1360,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// load VGrad Gemm A
// load VGrad Gemm A
const
auto
p_nd_idx
=
const
auto
p_nd_idx
=
sfc_p_m0_n0_m1_n1_m2_n2
.
GetIndexTupleOfNumber
(
vgrad_gemm_loop_idx
);
sfc_p_m0_n0_m1_n1_m2_n2
.
GetIndexTupleOfNumber
(
vgrad_gemm_loop_idx
);
constexpr
auto
mwave_range
=
make_tuple
(
constexpr
auto
mwave_range
=
p_nd_idx
[
I2
],
p_nd_idx
[
I2
]
+
p_block_slice_lengths_m0_n0_m1_n1
[
I2
]);
make_tuple
(
p_nd_idx
[
I2
],
p_nd_idx
[
I2
]
+
p_block_slice_lengths_m0_n0_m1_n1
[
I2
]);
constexpr
auto
nwave_range
=
make_tuple
(
constexpr
auto
nwave_range
=
p_nd_idx
[
I3
],
p_nd_idx
[
I3
]
+
p_block_slice_lengths_m0_n0_m1_n1
[
I3
]);
make_tuple
(
p_nd_idx
[
I3
],
p_nd_idx
[
I3
]
+
p_block_slice_lengths_m0_n0_m1_n1
[
I3
]);
#if 0
#if 0
if(hipThreadIdx_x % 64 == 0)
if(hipThreadIdx_x % 64 == 0)
{
{
...
@@ -1385,12 +1379,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -1385,12 +1379,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range));
p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range));
}
}
#endif
#endif
if
(
p_thread_copy_subgroup
.
IsBelong
(
mwave_range
,
nwave_range
))
if
(
p_thread_copy_subgroup
.
IsBelong
(
mwave_range
,
nwave_range
))
{
{
p_thread_copy_vgpr_to_lds
.
Run
(
p_thread_copy_vgpr_to_lds
.
Run
(
p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_tuple
(
p_nd_idx
[
I0
],
p_nd_idx
[
I1
],
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
p_nd_idx
[
I0
],
p_nd_idx
[
I1
],
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
acc_thread_buf
,
acc_thread_buf
,
p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
p_block_buf
);
p_block_buf
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment