Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9dc05eaa
Commit
9dc05eaa
authored
Feb 12, 2023
by
ltqin
Browse files
change fp16 xdl to bf16
parent
c1ed00b6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
29 deletions
+81
-29
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+45
-29
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+36
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
9dc05eaa
...
...
@@ -85,6 +85,21 @@ template <typename DataType,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
template
<
typename
T
>
struct
TypeMap
{
using
type
=
T
;
};
#if defined(__gfx90a__)
template
<
>
struct
TypeMap
<
ck
::
half_t
>
{
using
type
=
ck
::
bhalf_t
;
};
#endif
using
LDSDataType
=
typename
TypeMap
<
DataType
>::
type
;
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
...
...
@@ -126,7 +141,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
mfma
=
MfmaSelector
<
LDS
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
...
...
@@ -142,7 +157,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
index_t
M
,
const
index_t
N
)
{
constexpr
auto
mfma
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
mfma
=
MfmaSelector
<
LDS
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
...
...
@@ -456,7 +471,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
LDS
DataType
,
GridDesc_K0_M_K1
,
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
...
...
@@ -481,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
LDS
DataType
,
GridDesc_K0_N_K1
,
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
...
...
@@ -496,13 +511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
LDSDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
LDS
DataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
@@ -564,7 +580,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using
ABlockwiseCopy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
FloatGemmAcc
,
DataType
,
LDS
DataType
,
decltype
(
a_src_thread_desc_k0_m_k1
),
decltype
(
a_thread_desc_k0_m_k1
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -583,7 +599,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
DataType
,
DataType
,
LDS
DataType
,
GridDesc_K0_N_K1
,
decltype
(
b_block_desc_bk0_n_bk1
),
B1BlockTransferSrcAccessOrder
,
...
...
@@ -614,11 +630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static
constexpr
index_t
GemmKPack
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
MfmaSelector
<
LDS
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
;
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
DataType
,
LDS
DataType
,
FloatGemmAcc
,
decltype
(
a_thread_desc_k0_m_k1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
@@ -634,7 +650,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
GemmKPack
,
true
,
// TransposeC
GemmKPack
,
// AMmaKStride
GemmKPack
*
XdlopsGemm
<
DataType
,
MPerXdl
,
NPerXdl
,
GemmKPack
,
false
>
{}
GemmKPack
*
XdlopsGemm
<
LDS
DataType
,
MPerXdl
,
NPerXdl
,
GemmKPack
,
false
>
{}
.
K0PerXdlops
/* BMmaKStride */
>
;
};
...
...
@@ -666,7 +682,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static
constexpr
index_t
GemmORepeat
=
Free1_O
/
GemmOWave
/
NPerXdl
;
static
constexpr
index_t
GemmMPack
=
math
::
max
(
math
::
lcm
(
A_M1
,
B_M1
),
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
LDS
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
using
BBlockSliceLengths
=
Sequence
<
B_M0
,
Free1_O
,
B_M1
>
;
using
BThreadClusterLengths
=
...
...
@@ -791,7 +807,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template
<
typename
ElementwiseOp
=
tensor_operation
::
element_wise
::
PassThrough
>
using
ABlockwiseCopy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
DataType
,
LDS
DataType
,
decltype
(
a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
ElementwiseOp
,
...
...
@@ -821,7 +837,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
Gemm2Params_N_O_M
::
BThreadClusterLengths
,
typename
Gemm2Params_N_O_M
::
BThreadClusterArrangeOrder
,
DataType
,
DataType
,
LDS
DataType
,
GridDesc_M0_O_M1
,
decltype
(
b_block_desc_m0_o_m1
),
typename
Gemm2Params_N_O_M
::
BThreadClusterArrangeOrder
,
// access order == thread order
...
...
@@ -838,7 +854,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using
BlockwiseGemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
DataType
,
LDS
DataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_m0_n_m1
),
decltype
(
b_block_desc_m0_o_m1
),
...
...
@@ -905,7 +921,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
{
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
DataType
);
static
constexpr
index_t
SrcScalarPerVector
=
16
/
sizeof
(
FloatGemmAcc
);
static
constexpr
auto
ThreadClusterLength_O
=
Number
<
BlockSliceLength_O_
/
SrcScalarPerVector
>
{};
static
constexpr
auto
ThreadClusterLength_M
=
Number
<
BlockSize_
/
ThreadClusterLength_O
>
{};
...
...
@@ -917,7 +933,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
using
SrcBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
,
FloatGemmAcc
,
ThreadSliceLength_M
*
ThreadSliceLength_O
,
true
>
;
...
...
@@ -1079,7 +1095,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static
constexpr
auto
b2_block_desc_m0_o_m1
=
GetB2BlockDescriptor_M0_O_M1
<
Gemm2Params_N_O_M
>
();
static
constexpr
auto
max_lds_align
=
Number
<
16
/
sizeof
(
DataType
)
>
{};
static
constexpr
auto
max_lds_align
=
Number
<
16
/
sizeof
(
LDS
DataType
)
>
{};
static
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
...
...
@@ -1115,13 +1131,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
const
index_t
gemm0_bytes_end
=
(
SharedMemTrait
::
a_block_space_size_aligned
+
SharedMemTrait
::
b_block_space_size_aligned
)
*
sizeof
(
DataType
);
sizeof
(
LDS
DataType
);
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
DataType
);
sizeof
(
LDS
DataType
);
const
index_t
vgrad_gemm_bytes_end
=
(
SharedMemTrait
::
p_block_space_size_aligned
+
SharedMemTrait
::
ygrad_block_space_size_aligned
)
*
sizeof
(
DataType
);
sizeof
(
LDS
DataType
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_size_aligned
)
*
...
...
@@ -1224,11 +1240,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// Gemm0: LDS allocation for A and B: be careful of alignment
auto
gemm0_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
static_cast
<
LDS
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
Gemm0
::
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
gemm0_b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
static_cast
<
LDS
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
Gemm0
::
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
// Gemm0: gridwise GEMM pipeline
...
...
@@ -1320,11 +1336,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
decltype
(
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
())
>
;
// Gemm1: VGPR allocation for A and LDS allocation for B
auto
gemm1_a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
>
(
auto
gemm1_a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
LDS
DataType
>
(
Gemm1
::
a_thread_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
gemm1_b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
static_cast
<
LDS
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
Gemm1
::
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
// dQ: transform input and output tensor descriptors
...
...
@@ -1516,11 +1532,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// Gemm2: LDS allocation for A and B: be careful of alignment
auto
gemm2_a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a2_block_space_offset
,
static_cast
<
LDS
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
a2_block_space_offset
,
Gemm2
::
a_block_desc_m0_n_m1
.
GetElementSpaceSize
());
auto
gemm2_b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b2_block_space_offset
,
static_cast
<
LDS
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b2_block_space_offset
,
Gemm2
::
b_block_desc_m0_o_m1
.
GetElementSpaceSize
());
// dV: transform input and output tensor descriptors
...
...
@@ -1640,7 +1656,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// performs double duty for both y and ygrad
auto
yygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
FloatGemmAcc
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
...
...
include/ck/utility/data_type.hpp
View file @
9dc05eaa
...
...
@@ -1010,6 +1010,42 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return
uint16_t
(
u
.
int32
>>
16
);
}
// convert fp16 to bf16
template
<
>
inline
__host__
__device__
bhalf_t
type_convert
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
static_cast
<
float
>
(
x
)};
return
uint16_t
(
u
.
int32
>>
16
);
}
template
<
>
inline
__host__
__device__
bhalf2_t
type_convert
<
bhalf2_t
,
half2_t
>
(
half2_t
x
)
{
float
y0
{
0
},
y1
{
0
};
bhalf2_t
y
{
0
};
asm
volatile
(
"
\n
\
v_cvt_f32_f16 %0, %1
\n
\
"
:
"=v"
(
y0
)
:
"v"
(
x
));
asm
volatile
(
"
\n
\
v_cvt_f32_f16 %0, %1 src0_sel:WORD_1
\n
\
"
:
"=v"
(
y1
)
:
"v"
(
x
));
asm
volatile
(
"
\n
\
v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1]
\n
\
"
:
"=v"
(
y
)
:
"v"
(
y0
),
"v"
(
y1
));
return
y
;
}
template
<
typename
T
>
struct
NumericLimits
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment