gaoqiong / composable_kernel / Commits

Commit 326e9ee8, authored Mar 02, 2023 by danyao12

    support bwd fp16&bf16

Parent: 32b03f33

Showing 6 changed files with 214 additions and 59 deletions (+214 -59)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt                                                      +1   -1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp                       +9   -5
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1.cpp                        +23  -11
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp   +3   -1
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp        +43  -41
include/ck/utility/generic_memory_space_atomic.hpp                                                             +135 -0
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt  (view file @ 326e9ee8)

@@ -10,7 +10,7 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_
 add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
 add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
-add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp)
 add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (view file @ 326e9ee8)

@@ -50,9 +50,10 @@ Kernel outputs:
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

-using F16 = ck::half_t;
-using F32 = float;
-using U16 = unsigned short;
+using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F32  = float;
+using U16  = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;
@@ -748,9 +749,12 @@ int run(int argc, char* argv[])
             {
                 auto idx_gmo = idx_gmn;
                 idx_gmo[2]   = o;
-                ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
+                ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
+                               ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
             }
-            self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
+            self(idx_gmn) = ck::type_convert<DataType>(
+                ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
+                (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
         });
 #if PRINT_HOST
         {
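
Note on the hunk above: the host reference now converts each operand to AccDataType (float) before the multiply-accumulate instead of working directly in the 16-bit storage type. A minimal host-side sketch of why this matters, independent of ck (fp16-style precision is approximated here by truncating the float mantissa to 10 bits):

#include <cstdio>
#include <cstdint>
#include <cstring>

// Approximate fp16 storage by zeroing the low 13 mantissa bits of a float
// (fp16 keeps ~10 explicit mantissa bits); illustration only, not ck code.
static float chop_to_fp16_like(float x)
{
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u &= 0xFFFFE000u;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main()
{
    float acc_16 = 0.f; // accumulator forced back to fp16-like precision each step
    float acc_32 = 0.f; // accumulator kept in float, as the commit now does
    for(int i = 0; i < 4096; ++i)
    {
        const float prod = 0.3f * 0.7f;
        acc_16 = chop_to_fp16_like(acc_16 + prod); // stalls near 256: the 0.21
                                                   // increment falls below the
                                                   // representable step size
        acc_32 += prod;                            // ~860.16
    }
    std::printf("16-bit-style accumulation: %f\nfloat accumulation: %f\n",
                acc_16, acc_32);
}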
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
→ example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1.cpp  (view file @ 326e9ee8)
@@ -49,9 +49,10 @@ Kernel outputs:
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

-using F16 = ck::half_t;
-using F32 = float;
-using U16 = unsigned short;
+using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F32  = float;
+using U16  = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;
@@ -59,7 +60,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp   = PassThrough;

-using DataType        = F16;
+using DataType        = BF16;
+using GemmDataType    = BF16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
 using LSEDataType     = F32;
@@ -101,6 +103,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
@@ -169,6 +172,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        GemmDataType,
         ZDataType,
         LSEDataType,
         Acc0BiasDataType,
@@ -340,16 +344,21 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M = 512; // 512
-    ck::index_t N = 512; // 512
-    ck::index_t K = 64;
-    ck::index_t O = 64;
+    ck::index_t M = 512; // 512
+    ck::index_t N = 512; // 512
+#if USING_HD32
+    ck::index_t K = 32; // K/O<=32
+    ck::index_t O = 32;
+#else
+    ck::index_t K = 64; // 32<K/O<=64
+    ck::index_t O = 64;
+#endif
     ck::index_t G0 = 4; // 54
     ck::index_t G1 = 6; // 16

     float alpha = 1.f / std::sqrt(K);

-    bool input_permute  = true; // false;
+    bool input_permute  = false;
     bool output_permute = false;

     float p_drop = 0.2;
@@ -747,9 +756,12 @@ int run(int argc, char* argv[])
             {
                 auto idx_gmo = idx_gmn;
                 idx_gmo[2]   = o;
-                ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
+                ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
+                               ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
             }
-            self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
+            self(idx_gmn) = ck::type_convert<DataType>(
+                ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
+                (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
         });
 #if PRINT_HOST
         {
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp  (view file @ 326e9ee8)
@@ -173,6 +173,7 @@ template <index_t NumDimG,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
           typename DataType,
+          typename GemmDataType,
           typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
@@ -598,9 +599,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // GridwiseGemm
         using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
             DataType, // TODO: distinguish A/B datatype
-            LSEDataType,
+            GemmDataType,
             GemmAccDataType,
             CShuffleDataType,
+            LSEDataType,
             AElementwiseOperation,
             BElementwiseOperation,
             AccElementwiseOperation,
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp  (view file @ 326e9ee8)
@@ -21,6 +21,7 @@
 namespace ck {

 template <typename DataType,
+          typename GemmDataType,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           typename FloatLSE,
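
Note: the remaining hunks in this file apply the new parameter mechanically; every LDS buffer, blockwise transfer, and XDL gemm that previously used the I/O element type DataType now uses GemmDataType, while DataType stays the element type of the global-memory tensors. A hypothetical, much-simplified sketch of that split (not the real gridwise signature, which carries dozens more parameters):

#include <cstddef>
#include <vector>

template <typename DataType,      // element type of global-memory tensors
          typename GemmDataType,  // element type staged into LDS / fed to the gemm
          typename FloatGemmAcc>  // accumulator type
struct GridwiseGemmSketch
{
    static FloatGemmAcc Run(const std::vector<DataType>& a,
                            const std::vector<DataType>& b)
    {
        FloatGemmAcc acc = 0;
        for(std::size_t i = 0; i < a.size(); ++i)
        {
            // convert I/O type -> gemm type at the copy stage,
            // then accumulate in FloatGemmAcc
            acc += static_cast<FloatGemmAcc>(static_cast<GemmDataType>(a[i])) *
                   static_cast<FloatGemmAcc>(static_cast<GemmDataType>(b[i]));
        }
        return acc;
    }
};

// e.g. GridwiseGemmSketch<float, float, float>::Run(a, b);
// in this commit the instantiation is effectively <BF16, BF16, float>.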
@@ -121,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);

-        constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3   = mfma.num_groups_per_blk;
         constexpr auto N4   = mfma.num_input_blks;
         constexpr auto N5   = mfma.group_size;
@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            GemmDataType,
             GridDesc_K0_M_K1,
             decltype(q_block_desc_k0_m_k1),
             ABlockTransferSrcAccessOrder,
@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            GemmDataType,
             GridDesc_K0_N_K1,
             decltype(k_block_desc_k0_n_k1),
             BBlockTransferSrcAccessOrder,
@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            GemmDataType,
             GridDesc_K0_N_K1,
             decltype(v_block_desc_k0_n_k1),
             BBlockTransferSrcAccessOrder,
@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             DataType,
-            DataType,
+            GemmDataType,
             GridDesc_K0_M_K1,
             decltype(ygrad_block_desc_k0_m_k1),
             ABlockTransferSrcAccessOrder,
@@ -506,13 +507,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                 BBlockDesc_BK0_N_BK1{});
         }

-        static constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        static constexpr index_t KPack = math::max(
+            math::lcm(AK1, BK1),
+            MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         // Blockwise gemm with transposed XDL output
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -587,7 +589,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatGemmAcc,
-            DataType,
+            GemmDataType,
             decltype(a_src_thread_desc_k0_m_k1),
             decltype(a_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
@@ -610,7 +612,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
         static constexpr index_t GemmKPack =
-            MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
+            MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;

         static constexpr index_t GemmMWave = Gemm0MWaves;
         static constexpr index_t GemmNWave = Gemm0NWaves;
@@ -676,8 +678,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();

-        using BBlockwiseCopy = ThreadwiseTensorSliceTransfer_v2<DataType,
-                                                                DataType,
+        using BBlockwiseCopy = ThreadwiseTensorSliceTransfer_v2<GemmDataType,
+                                                                GemmDataType,
             decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
             decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
             BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
@@ -692,7 +694,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_thread_desc_k0_m_k1),
             decltype(b_thread_desc_k0_n_k1),
@@ -733,12 +735,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
         static constexpr index_t GemmMLoop   = Free1_M / Sum_M;
         static constexpr index_t GemmMPack =
-            math::max(A_M1, MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+            math::max(A_M1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         static constexpr index_t B_M3 = GemmMPack; // 8
         static constexpr index_t B_M2 =
-            XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmMPack, false>{}.K0PerXdlops; // 2
-        static constexpr index_t B_M1 = Sum_M / B_M2 / B_M3;                        // 4
-        static constexpr index_t B_M0 = GemmMLoop;                                  // 2
+            XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmMPack, false>{}.K0PerXdlops; // 2
+        static constexpr index_t B_M1 = Sum_M / B_M2 / B_M3;                            // 4
+        static constexpr index_t B_M0 = GemmMLoop;                                      // 2

         __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1_M2_N2()
         {
@@ -875,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
             FloatGemmAcc,
-            DataType,
+            GemmDataType,
             decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             ElementwiseOp,
@@ -968,8 +970,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr auto b_thread_desc_m0_o_m1 = MakeBThreadDesc_M0_O_M1();

-        using BBlockwiseCopy = ThreadwiseTensorSliceTransfer_v2<DataType,
-                                                                DataType,
+        using BBlockwiseCopy = ThreadwiseTensorSliceTransfer_v2<GemmDataType,
+                                                                GemmDataType,
             decltype(b_block_desc_o0_o1_o2_m0_m1_m2_m3),
             decltype(b_thread_desc_o0_o1_o2_m0_m1_m2_m3),
             BThreadSlice_O0_O1_O2_M0_M1_M2_M3,
@@ -985,7 +987,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         using BlockwiseGemm = BlockwiseGemmXdlops_v2<
             BlockSize,
-            DataType,
+            GemmDataType,
             FloatGemmAcc,
             decltype(a_block_desc_m0_n_m1),
             decltype(b_thread_desc_m0_o_m1),
@@ -1001,7 +1003,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             Gemm2Params_N_O_M::GemmMPack,
             true, // TransposeC
             Gemm2Params_N_O_M::GemmMPack*
-                XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm2Params_N_O_M::GemmMPack, false>{}
+                XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params_N_O_M::GemmMPack, false>{}
                     .K0PerXdlops,
             Gemm2Params_N_O_M::GemmMPack>;
@@ -1092,7 +1094,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");

             using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
-                                            DataType,
+                                            FloatGemmAcc,
                                             ThreadSliceLength_M * ThreadSliceLength_O,
                                             true>;
@@ -1165,7 +1167,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr auto p_slash_sgrad_block_desc_m0_n_m1 =
             GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();

-        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
+        static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};

         static constexpr auto q_block_space_size_aligned = math::integer_least_multiple(
             q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
@@ -1193,7 +1195,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         static constexpr auto reduction_space_offset =
             (ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value) *
-            sizeof(DataType) / sizeof(FloatGemmAcc);
+            sizeof(GemmDataType) / sizeof(FloatGemmAcc);

         // LDS allocation for C shuffle in LDS
         static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
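
Note: the sizeof ratio above rescales an offset counted in GemmDataType elements into FloatGemmAcc elements at the same byte position, since the reduction workspace indexes LDS with the accumulator type. A standalone check with assumed block extents (the 4096s here are made up for illustration):

#include <cstddef>

using GemmDataType = unsigned short; // 2-byte stand-in for fp16/bf16
using FloatGemmAcc = float;          // 4 bytes

constexpr std::size_t ygrad_elems = 4096; // assumed LDS block sizes
constexpr std::size_t q_elems     = 4096;

constexpr std::size_t reduction_space_offset =
    (ygrad_elems + q_elems) * sizeof(GemmDataType) / sizeof(FloatGemmAcc);

// 8192 two-byte elements occupy 16384 bytes = 4096 four-byte elements.
static_assert(reduction_space_offset == 4096, "offset rescale check");

int main() { return 0; }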
@@ -1206,14 +1208,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         {
             const index_t k_bytes_end =
                 (SharedMemTrait::k_block_space_offset + SharedMemTrait::k_block_space_size_aligned) *
-                sizeof(DataType);
+                sizeof(GemmDataType);
             const index_t v_bytes_end =
                 (SharedMemTrait::v_block_space_offset + SharedMemTrait::v_block_space_size_aligned) *
-                sizeof(DataType);
+                sizeof(GemmDataType);
             const index_t p_slash_sgrad_bytes_end =
                 (SharedMemTrait::p_slash_sgrad_block_space_offset +
                  SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
-                sizeof(DataType);
+                sizeof(GemmDataType);
             const index_t softmax_bytes_end =
                 (SharedMemTrait::reduction_space_offset +
                  SharedMemTrait::reduction_space_size_aligned) *
                 sizeof(FloatGemmAcc);
@@ -1263,8 +1265,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                              const float p_drop,
                              ck::philox& ph)
     {
-        const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
-        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
+        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
+        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
         const ushort p_dropout_in_16bits =
             __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
         const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
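
Note: the 16-bit dropout threshold above is plain arithmetic over the keep probability; with the example file's p_drop = 0.2 it works out as follows (standalone check, not ck code):

#include <cmath>
#include <cstdio>

int main()
{
    const float p_drop    = 0.2f;
    const float p_dropout = 1.0f - p_drop; // keep probability, 0.8
    const unsigned short threshold =
        static_cast<unsigned short>(std::floor(p_dropout * 65535.0)); // 52428
    std::printf("threshold = %u (keeps ~%.1f%% of 16-bit random values)\n",
                static_cast<unsigned>(threshold), 100.0 * threshold / 65535.0);
}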
@@ -1315,19 +1317,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // LDS allocation for Q / K / V / dY
         auto q_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
             GemmBlockwiseCopy::q_block_desc_k0_m_k1.GetElementSpaceSize());
         auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
             GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());
         auto v_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::v_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::v_block_space_offset,
             GemmBlockwiseCopy::v_block_desc_k0_n_k1.GetElementSpaceSize());
         auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
             GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1.GetElementSpaceSize());

         // Q matrix blockwise copy
@@ -1394,10 +1396,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;

         // Gemm1: VGPR allocation for A and B
-        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
             Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
-        auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+        auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
             Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());

         // dQ: transform input and output tensor descriptors
@@ -1589,10 +1591,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // Gemm2: LDS allocation for A and B: be careful of alignment
         auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
+            static_cast<GemmDataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
             Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
-        auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+        auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
             Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3.GetElementSpaceSize());

         // dV: transform input and output tensor descriptors
@@ -1722,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // performs for y
         auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
             DataType,
-            DataType,
+            FloatGemmAcc,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
             decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
@@ -1735,8 +1737,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // performs for ygrad
         auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
-            DataType,
-            DataType,
+            GemmDataType,
+            FloatGemmAcc,
             decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
             decltype(ygrad_thread_desc_m_o),
             decltype(ygrad_thread_desc_m_o.GetLengths()),
include/ck/utility/generic_memory_space_atomic.hpp  (view file @ 326e9ee8)
@@ -71,6 +71,141 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
     return vy.template AsType<double2_t>()[I0];
 }

+inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t& b)
+{
+    half2_t rtn;
+    rtn[0] = a[0] + b[0];
+    rtn[1] = a[1] + b[1];
+    return rtn;
+}
+
+union U32FP162_ADDR
+{
+    uint32_t* u32_a;
+    half2_t* fp162_a;
+};
+union U32FP162
+{
+    uint32_t u32;
+    half2_t fp162;
+};
+
+template <>
+__device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
+{
+    U32FP162_ADDR dword_addr;
+    U32FP162 cur_v;
+    U32FP162 new_;
+    uint32_t old_v, new_v;
+
+    dword_addr.fp162_a = p_dst;
+    cur_v.u32          = *dword_addr.u32_a;
+
+    do
+    {
+        old_v      = cur_v.u32;
+        new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
+        new_v      = new_.u32;
+        cur_v.u32  = atomicCAS(dword_addr.u32_a, old_v, new_v);
+    } while(cur_v.u32 != old_v);
+
+    return x;
+}
+
+// template <>
+// __device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
+// {
+//     uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
+//     uint32_t cur_v = *dword_addr;
+//     uint32_t old_v, new_v;
+//     do {
+//         old_v = cur_v;
+//         half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
+//         new_v = *reinterpret_cast<uint32_t*>(&new_);
+//         cur_v = atomicCAS(dword_addr, old_v, new_v);
+//     }while(cur_v != old_v);
+//     return x;
+// }
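
Note: the specialization above emulates a packed-fp16 atomic add with a 32-bit atomicCAS retry loop, since the hardware exposes only dword-sized CAS for this case. A host-side analog of the same pattern using std::atomic (illustration only, not ck code; plain integer adds stand in for the fp16 lane adds):

#include <atomic>
#include <cstdint>
#include <cstdio>

static uint32_t pack(uint16_t lo, uint16_t hi)
{
    return uint32_t(lo) | (uint32_t(hi) << 16);
}

// Retry loop mirroring atomic_add<half2_t>: recompute the packed result
// from the freshly observed value until the compare-exchange succeeds.
static void atomic_add_packed(std::atomic<uint32_t>& dst, uint16_t a, uint16_t b)
{
    uint32_t cur = dst.load();
    uint32_t desired;
    do
    {
        const uint16_t lo = static_cast<uint16_t>((cur & 0xFFFFu) + a);
        const uint16_t hi = static_cast<uint16_t>((cur >> 16) + b);
        desired           = pack(lo, hi);
    } while(!dst.compare_exchange_weak(cur, desired)); // cur reloaded on failure
}

int main()
{
    std::atomic<uint32_t> v{pack(1, 2)};
    atomic_add_packed(v, 10, 20);
    const uint32_t r = v.load();
    std::printf("lo=%u hi=%u\n", r & 0xFFFFu, r >> 16); // lo=11 hi=22
}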
+// union U16BF16 {
+//     uint16_t u16;
+//     bhalf_t bf16;
+// };
+// inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b){
+//     U16BF16 xa {.bf16 = a};
+//     U16BF16 xb {.bf16 = b};
+//     U16BF16 xr;
+//     xr.u16 = xa.u16 + xb.u16;
+//     return xr.bf16;
+// }
+inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b)
+{
+    return type_convert<bhalf_t>(type_convert<float>(a) + type_convert<float>(b));
+}
+
+inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2_t& b)
+{
+    bhalf2_t rtn;
+    rtn[0] = add_bf16_t(a[0], b[0]);
+    rtn[1] = add_bf16_t(a[1], b[1]);
+    return rtn;
+}
+
+union U32BF162_ADDR
+{
+    uint32_t* u32_a;
+    bhalf2_t* bf162_a;
+};
+union U32BF162
+{
+    uint32_t u32;
+    bhalf2_t bf162;
+};
+
+template <>
+__device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
+{
+    U32BF162_ADDR dword_addr;
+    U32BF162 cur_v;
+    U32BF162 new_;
+    uint32_t old_v, new_v;
+
+    dword_addr.bf162_a = p_dst;
+    cur_v.u32          = *dword_addr.u32_a;
+
+    do
+    {
+        old_v      = cur_v.u32;
+        new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
+        new_v      = new_.u32;
+        cur_v.u32  = atomicCAS(dword_addr.u32_a, old_v, new_v);
+    } while(cur_v.u32 != old_v);
+
+    return x;
+}
+// template <>
+// __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
+// {
+//     uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
+//     uint32_t cur_v = *dword_addr;
+//     uint32_t old_v, new_v;
+//     do {
+//         old_v = cur_v;
+//         bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
+//         new_v = *reinterpret_cast<uint32_t*>(&new_);
+//         cur_v = atomicCAS(dword_addr, old_v, new_v);
+//     }while(cur_v != old_v);
+//     return x;
+// }
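
Note: add_bf16_t deliberately round-trips through float, since there is no native bf16 add; the commented-out draft that sums the raw 16-bit patterns would be numerically wrong. A host-only sketch of the same round-trip (assumption: bf16 is the high 16 bits of an IEEE float; truncating conversion used for simplicity, while real conversions typically round to nearest even):

#include <cstdint>
#include <cstring>
#include <cstdio>

static uint16_t float_to_bf16(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<uint16_t>(u >> 16); // truncate low mantissa bits
}
static float bf16_to_float(uint16_t h)
{
    const uint32_t u = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}
static uint16_t add_bf16(uint16_t a, uint16_t b)
{
    // widen to float, add, narrow back: the pattern add_bf16_t uses
    return float_to_bf16(bf16_to_float(a) + bf16_to_float(b));
}

int main()
{
    const uint16_t one = float_to_bf16(1.0f); // 0x3F80
    std::printf("1.0 + 1.0 = %g\n", bf16_to_float(add_bf16(one, one))); // 2.0
    std::printf("raw bit-pattern sum would be 0x%04X (a huge float)\n", one + one);
}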
 // Caution: DO NOT REMOVE
 // intentionally have only declaration but no definition to cause compilation failure when trying to
 // instantiate this template. The purpose is to make the implementation of atomic_max explicit for
 ...