gaoqiong / composable_kernel / Commits / f1b2e521

Commit f1b2e521 authored Dec 07, 2022 by Anthony Chang

format
parent 4ae9919e

Changes 3
Showing 3 changed files with 67 additions and 67 deletions (+67 -67)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp (+1 -1)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp (+15 -8)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+51 -58)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...
...
@@ -15,7 +15,7 @@ Outputs:
*/
-#pragma clang diagnostic ignored "-Wunused-variable"        // TODO ANT: remove
+#pragma clang diagnostic ignored "-Wunused-variable" // TODO ANT: remove
#define PRINT_HOST 0
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...
...
@@ -361,10 +361,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec).second;
-        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec, ",") << std::endl;
-        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ", v_gs_os_ns_strides_vec, ",") << std::endl;
-        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
-        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec,
+        // ",") << std::endl; LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ",
+        // v_gs_os_ns_strides_vec, ",") << std::endl; LogRangeAsType<float>(std::cout <<
+        // "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec,
+        // ",") << std::endl;

        return PadTensorDescriptor(vgrad_desc_nraw_oraw,
                                   make_tuple(NPerBlock, Gemm1NPerBlock),
                                   Sequence<padder.PadN, padder.PadO>{});
...
...
@@ -685,7 +687,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
-              compute_base_ptr_of_batch_{a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_, type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
+              compute_base_ptr_of_batch_{a_grid_desc_g_m_k_,
+                                         b_grid_desc_g_n_k_,
+                                         b1_grid_desc_g_n_k_,
+                                         c_grid_desc_g_m_n_,
+                                         type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
        {
            // TODO ANT: implement bias addition
            ignore = p_acc0_biases;
...
...
@@ -726,9 +732,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                      << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
            // c_grid_desc_g_m_n_.Print();
-            std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", " << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
-            std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
+            std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
+                      << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
+            std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
+                      << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
+                      << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
        }
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...
...
@@ -415,13 +415,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    __host__ __device__ static constexpr auto
    MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const LSEGridDesc_M& lse_grid_desc_m)
    {
        const index_t M         = lse_grid_desc_m.GetLength(I0);
        const index_t MBlock    = M / MPerBlock;
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

        const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
            lse_grid_desc_m,
            make_tuple(make_unmerge_transform(
                make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2, 3>{}));
...
...
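For reference, the unmerge transform in the hunk above splits the flat LSE row index M into the four factors (MBlock, MXdlPerWave, MWave, MPerXdl), with MWave = MPerBlock / (MXdlPerWave * MPerXdl). The following is a minimal standalone sketch of that index arithmetic; the tile sizes here are illustrative placeholders, not this kernel's actual template parameters.

    #include <array>
    #include <cassert>
    #include <cstdio>

    // Illustrative tile sizes only; the real values are the kernel's template
    // parameters (MPerBlock, MXdlPerWave, MPerXdl).
    constexpr int MPerBlock   = 128;
    constexpr int MXdlPerWave = 4;
    constexpr int MPerXdl     = 32;
    constexpr int MWave       = MPerBlock / (MXdlPerWave * MPerXdl); // = 1

    // Decompose a flat row index m into (mblock, mrepeat, mwave, mperxdl),
    // mirroring what make_unmerge_transform does for the LSE descriptor above.
    std::array<int, 4> unmerge_m(int m)
    {
        const int mperxdl = m % MPerXdl;
        m /= MPerXdl;
        const int mwave = m % MWave;
        m /= MWave;
        const int mrepeat = m % MXdlPerWave;
        m /= MXdlPerWave;
        return {m, mrepeat, mwave, mperxdl}; // m is now the block index
    }

    int main()
    {
        const auto idx = unmerge_m(200); // 200 = 1 * 128 + 2 * 32 + 8
        std::printf("mblock=%d mrepeat=%d mwave=%d mperxdl=%d\n", idx[0], idx[1], idx[2], idx[3]);
        // Recombining the factors gives back the original flat index.
        assert(idx[0] * MPerBlock + (idx[1] * MWave + idx[2]) * MPerXdl + idx[3] == 200);
        return 0;
    }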
@@ -469,10 +470,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset     = 0;
        static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset    = 0;
        static constexpr auto p_block_space_offset     = 0;
        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
// LDS allocation for reduction
...
...
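The offsets in the hunk above are produced by stacking each LDS buffer after the previous buffer's aligned size; some offsets restart at 0, presumably because those buffers are not live at the same time and can reuse the same LDS region. A small self-contained sketch of the arithmetic follows, using made-up sizes and alignment rather than the real block-descriptor sizes.

    #include <cstdio>

    // Round size up to the next multiple of align, i.e. the role played by
    // math::integer_least_multiple in the hunk above.
    constexpr int integer_least_multiple(int size, int align)
    {
        return ((size + align - 1) / align) * align;
    }

    int main()
    {
        // Made-up element counts and alignment, for illustration only.
        constexpr int max_lds_align        = 8;
        constexpr int a_block_size_aligned = integer_least_multiple(130, max_lds_align); // 136
        constexpr int p_block_size_aligned = integer_least_multiple(70, max_lds_align);  // 72

        // Coexisting buffers are stacked back to back; a buffer whose lifetime
        // does not overlap an earlier one can start again at offset 0.
        constexpr int a_block_space_offset     = 0;
        constexpr int b_block_space_offset     = a_block_size_aligned; // 136
        constexpr int p_block_space_offset     = 0;
        constexpr int ygrad_block_space_offset = p_block_size_aligned; // 72

        std::printf("a=%d b=%d p=%d ygrad=%d\n",
                    a_block_space_offset,
                    b_block_space_offset,
                    p_block_space_offset,
                    ygrad_block_space_offset);
        return 0;
    }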
@@ -1015,18 +1016,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                 1,              // DstScalarPerVector
                                                 InMemoryDataOperationEnum::Set,
                                                 1,              // DstScalarStrideInVector
                                                 true>{
                p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                make_multi_index(
                    p_thread_origin_nd_idx_on_block[I0],
                    p_thread_origin_nd_idx_on_block[I1],
                    p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
                    p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
                    p_thread_origin_nd_idx_on_block[I4],
                    p_thread_origin_nd_idx_on_block[I5],
                    p_thread_origin_nd_idx_on_block[I6],
                    p_thread_origin_nd_idx_on_block[I7]),
                tensor_operation::element_wise::PassThrough{}};
// Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
// p_block_slice_lengths_m0_n0_m1_n1[I1],
...
...
@@ -1087,18 +1087,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
            ygrad_block_desc_m0_o_m1.GetElementSpaceSize());

        auto vgrad_blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                DataType,
                                                                FloatGemmAcc,
                                                                decltype(p_block_desc_m0_n_m1),
                                                                decltype(ygrad_block_desc_m0_o_m1),
                                                                MPerXdl,
                                                                NPerXdl,
                                                                VGradGemmTile_N_O_M::GemmNRepeat,
                                                                VGradGemmTile_N_O_M::GemmORepeat,
                                                                VGradGemmTile_N_O_M::GemmMPack,
                                                                true>{}; // TranspossC

        auto vgrad_acc_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
...
...
@@ -1107,12 +1107,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
            vgrad_grid_desc_n_o,
            make_tuple(
                make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmNWave, MPerXdl)),
                make_unmerge_transform(make_tuple(I1, VGradGemmTile_N_O_M::GemmOWave, NPerXdl))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
...
...
@@ -1142,12 +1138,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const index_t o_thread_data_idx_on_grid =
            vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;

        const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto n_thread_data_nd_idx_on_grid =
            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
...
...
@@ -1171,9 +1165,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
            tensor_operation::element_wise::PassThrough,                      // CElementwiseOperation
            decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,                                 // AccessOrder
            7,                                                                // VectorDim
            2,                                                                // ScalarPerVector
            InMemoryDataOperationEnum::AtomicAdd,                             // GlobalMemoryDataOperation
            1,                                                                // DstScalarStrideInVector
            true>(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
...
...
@@ -1226,10 +1220,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;

        lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
                                           lse_grid_buf,
                                           lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
                                           make_tuple(I0, I0, I0, I0),
                                           lse_thread_buf);

        // gemm1 K loop
        index_t gemm1_k_block_outer_index = 0;
...
...
@@ -1366,10 +1360,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// load VGrad Gemm A
                const auto p_nd_idx =
                    sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
                constexpr auto mwave_range = make_tuple(
                    p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
                constexpr auto nwave_range = make_tuple(
                    p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
#if 0
if(hipThreadIdx_x % 64 == 0)
{
...
...
@@ -1385,12 +1379,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range));
}
#endif
                if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
                {
                    p_thread_copy_vgpr_to_lds.Run(
                        p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                        make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
                        acc_thread_buf,
                        p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                        p_block_buf);
...
...