gaoqiong / composable_kernel

Commit 20fc6679
Authored Feb 23, 2023 by danyao12
attn bwd prototype1 bf16
Parent: 010ed35f

Showing 10 changed files with 4669 additions and 9 deletions (+4669, -9)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt  (+1, -0)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (+3, -3)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_bf16.cpp  (+884, -0)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp  (+77, -0)
include/ck/ck.hpp  (+1, -1)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp  (+2, -0)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp  (+1258, -0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp  (+5, -5)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp  (+2369, -0)
include/ck/utility/generic_memory_space_atomic.hpp  (+69, -0)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -11,6 +11,7 @@ add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_
 add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_backward_pt1_bf16 batched_multihead_attention_backward_pt1_bf16.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -24,7 +24,7 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 1
+#define USING_MASK 0

 #include <iostream>
 #include <numeric>

@@ -268,8 +268,8 @@ int run(int argc, char* argv[])
     ck::index_t N = 512;
     ck::index_t K = 128;
     ck::index_t O = 128;
-    ck::index_t G0 = 3;
-    ck::index_t G1 = 2;
+    ck::index_t G0 = 54;
+    ck::index_t G1 = 16;

     float alpha = 1.f / std::sqrt(K);
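Note on the unchanged scaling line above: with K = 128 as in this example, alpha = 1.f / std::sqrt(K) evaluates to 1 / sqrt(128) ≈ 0.0884, i.e. the standard 1/sqrt(d_k) factor applied to Q·K^T before the softmax.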
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_bf16.cpp (new file, 0 → 100644)

This diff is collapsed.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp

@@ -25,6 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
+#define USING_HD32 0

 #include <iostream>
 #include <numeric>
@@ -86,6 +87,80 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
 static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+
+// Headdim/K/O should be a multiple of 8, and it's only supported up to 64 in prototype1.
+// If Headdim/K/O <= 32, ues 1st template.
+// If 32 < Headdim/K/O <= 64, ues 2nd template.
+#if USING_HD32
+// 1st template
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+        NumDimG,
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        NumDimO,
+        DataType,
+        ZDataType,
+        LSEDataType,
+        Acc0BiasDataType,
+        Acc1BiasDataType,
+        AccDataType,
+        ShuffleDataType,
+        QKVElementOp,
+        QKVElementOp,
+        Scale,
+        QKVElementOp,
+        YElementOp,
+        GemmSpec,
+        TensorSpecQ,
+        TensorSpecK,
+        TensorSpecV,
+        TensorSpecY,
+        1,
+        256,
+        128,            // MPerBlock
+        128,            // NPerBlock
+        32,             // KPerBlock
+        32,             // Gemm1NPerBlock
+        32,             // Gemm1KPerBlock
+        8,              // AK1
+        8,              // BK1
+        2,              // B1K1
+        32,             // MPerXDL
+        32,             // NPerXDL
+        1,              // MXdlPerWave
+        4,              // NXdlPerWave
+        1,              // Gemm1NXdlPerWave
+        1,              // Gemm2NXdlPerWave
+        S<4, 64, 1>,    // ABlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<4, 64, 1>,    // BBlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<8, 32, 1>,    // B1BlockTransfer
+        S<0, 2, 1>,
+        S<0, 2, 1>,
+        1,
+        4,
+        2,
+        false,
+        1,              // CShuffleMXdlPerWavePerShuffle
+        1,              // CShuffleNXdlPerWavePerShuffle
+        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
+#else
+//2nd template
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
         NumDimG,
@@ -125,6 +200,7 @@ using DeviceGemmInstance =
         1,           // MXdlPerWave
         4,           // NXdlPerWave
         2,           // Gemm1NXdlPerWave
+        2,           // Gemm2NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -151,6 +227,7 @@ using DeviceGemmInstance =
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization
+#endif

 // Ref Gemm0: S = alpha * Q * K^T
 // fp16 in, fp32 out
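The Headdim/K/O comment introduced in the hunk above is the only statement of prototype1's head-dim constraints. As a sketch only (the helper name and its use are assumptions, not part of this commit), the rule can be restated as:

// Hypothetical helper; it merely restates the comment added above:
// Headdim/K/O must be a multiple of 8 and at most 64 in prototype1,
// with USING_HD32 == 1 selecting the 1st template (head dim <= 32)
// and USING_HD32 == 0 selecting the 2nd template (32 < head dim <= 64).
#include <cassert>

inline int pick_pt1_template(int head_dim)
{
    assert(head_dim % 8 == 0); // prototype1 requires a multiple of 8
    assert(head_dim <= 64);    // prototype1 supports head dims only up to 64
    return (head_dim <= 32) ? 1 : 2; // 1 -> 1st template, 2 -> 2nd template
}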
include/ck/ck.hpp

@@ -118,7 +118,7 @@
 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
 #endif
 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
-#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
+#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 0
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1

 // experimental feature: in-regsiter sub-dword transpose
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp

@@ -204,6 +204,7 @@ template <index_t NumDimG,
           index_t MXdlPerWave,
           index_t NXdlPerWave,
           index_t Gemm1NXdlPerWave,
+          index_t Gemm2NXdlPerWave,
           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,

@@ -627,6 +628,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         MXdlPerWave,
         NXdlPerWave,
         Gemm1NXdlPerWave,
+        Gemm2NXdlPerWave,
         ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp (new file, 0 → 100644)

This diff is collapsed.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp

@@ -51,6 +51,7 @@ template <typename DataType,
           index_t MXdlPerWave,
           index_t NXdlPerWave,
           index_t Gemm1NXdlPerWave,
+          index_t Gemm2NXdlPerWave,
           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
@@ -726,9 +727,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     static_assert(Sum_M % MPerXdl == 0, "");

-    static constexpr index_t GemmNWave   = 2;
+    static constexpr index_t GemmNWave   = Free0_N / Gemm2NXdlPerWave / MPerXdl;
     static constexpr index_t GemmOWave   = BlockSize / get_warp_size() / GemmNWave;
-    static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
+    static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
     static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
     static constexpr index_t GemmMLoop   = Free1_M / Sum_M;
     static constexpr index_t GemmMPack   =
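A worked example of the tile arithmetic changed in the hunk above, with purely illustrative values (BlockSize, warp size, MPerXdl, NPerXdl, Free0_N, Free1_O and Gemm2NXdlPerWave are assumptions here, not read from the kernel):

// Illustrative numbers only; they show how the new Gemm2NXdlPerWave parameter
// redistributes the N extent between waves and per-wave repeats while keeping
// GemmNWave * GemmNRepeat * MPerXdl == Free0_N.
constexpr int BlockSize        = 256; // assumed
constexpr int warp_size        = 64;  // gfx9 wavefront size
constexpr int MPerXdl          = 32;  // assumed
constexpr int NPerXdl          = 32;  // assumed
constexpr int Free0_N          = 128; // assumed
constexpr int Free1_O          = 64;  // assumed
constexpr int Gemm2NXdlPerWave = 2;   // assumed

constexpr int GemmNWave   = Free0_N / Gemm2NXdlPerWave / MPerXdl; // 128 / 2 / 32 = 2
constexpr int GemmOWave   = BlockSize / warp_size / GemmNWave;    // 256 / 64 / 2 = 2
constexpr int GemmNRepeat = Gemm2NXdlPerWave;                     // 2
constexpr int GemmORepeat = Free1_O / GemmOWave / NPerXdl;        // 64 / 2 / 32 = 1

static_assert(GemmNWave * GemmNRepeat * MPerXdl == Free0_N, "N tiles must cover Free0_N");
static_assert(GemmOWave * GemmORepeat * NPerXdl == Free1_O, "O tiles must cover Free1_O");

With these numbers the old hard-coded GemmNWave = 2 gives the same split; the point of the change is that the wave/repeat split along N is now driven by the Gemm2NXdlPerWave template parameter instead of a fixed constant.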
@@ -1563,8 +1564,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                        n3, // NInputNum
                        n4>,
              Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-             9, // DstVectorDim
-             1, // DstScalarPerVector
+             9, // DstVectorDim
+             1, // DstScalarPerVector
              InMemoryDataOperationEnum::Set,
              1, // DstScalarStrideInVector
              true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
@@ -1937,7 +1938,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
                 else
                 {
                     s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                 }
             });
         }
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1_bf16v1.hpp (new file, 0 → 100644)

This diff is collapsed.
include/ck/utility/generic_memory_space_atomic.hpp

@@ -71,6 +71,75 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
     return vy.template AsType<double2_t>()[I0];
 }

+inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t& b)
+{
+    half2_t rtn;
+    rtn[0] = a[0] + b[0];
+    rtn[1] = a[1] + b[1];
+    return rtn;
+}
+
+template <>
+__device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
+{
+    uint32_t* dword_addr = reinterpret_cast<uint32_t*>(p_dst);
+    uint32_t cur_v       = *dword_addr;
+    uint32_t old_v, new_v;
+
+    do
+    {
+        old_v        = cur_v;
+        half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
+        new_v        = *reinterpret_cast<uint32_t*>(&new_);
+        cur_v        = atomicCAS(dword_addr, old_v, new_v);
+    } while(cur_v != old_v);
+
+    return x;
+}
+
+// union U16BF16 {
+//     uint16_t u16;
+//     bhalf_t bf16;
+// };
+// inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b){
+//     U16BF16 xa {.bf16 = a};
+//     U16BF16 xb {.bf16 = b};
+//     U16BF16 xr;
+//     xr.u16 = xa.u16 + xb.u16;
+//     return xr.bf16;
+// }
+
+inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b)
+{
+    return type_convert<bhalf_t>(type_convert<float>(a) + type_convert<float>(b));
+}
+
+inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2_t& b)
+{
+    bhalf2_t rtn;
+    rtn[0] = add_bf16_t(a[0], b[0]);
+    rtn[1] = add_bf16_t(a[1], b[1]);
+    return rtn;
+}
+
+template <>
+__device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
+{
+    uint32_t* dword_addr = reinterpret_cast<uint32_t*>(p_dst);
+    uint32_t cur_v       = *dword_addr;
+    uint32_t old_v, new_v;
+
+    do
+    {
+        old_v         = cur_v;
+        bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
+        new_v         = *reinterpret_cast<uint32_t*>(&new_);
+        cur_v         = atomicCAS(dword_addr, old_v, new_v);
+    } while(cur_v != old_v);
+
+    return x;
+}
+
 // Caution: DO NOT REMOVE
 // intentionally have only declaration but no definition to cause compilation failure when trying to
 // instantiate this template. The purpose is to make the implementation of atomic_max explicit for
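The atomic_add specializations added above all follow the same pattern: pack two 16-bit values into one 32-bit word and emulate the atomic add with an atomicCAS retry loop. Below is a minimal standalone sketch of that technique, written against CUDA's __half2 purely for illustration; the function name and types are not CK code, and while the CK specializations above return their input x, this sketch returns the value observed before the add.

#include <cuda_fp16.h>

__device__ __half2 atomic_add_half2_via_cas(__half2* p_dst, __half2 x)
{
    unsigned int* dword_addr = reinterpret_cast<unsigned int*>(p_dst);
    unsigned int observed    = *dword_addr; // initial read; the CAS loop handles staleness
    unsigned int assumed;

    do
    {
        assumed = observed;
        // Compute the packed fp16 sum on a private copy of the observed 32-bit payload.
        __half2 updated = __hadd2(*reinterpret_cast<const __half2*>(&assumed), x);
        // Publish only if the destination still holds `assumed`; otherwise retry with
        // the value atomicCAS reports as currently stored.
        observed = atomicCAS(dword_addr, assumed, *reinterpret_cast<unsigned int*>(&updated));
    } while(observed != assumed);

    return *reinterpret_cast<const __half2*>(&assumed); // value before our add
}

For the bf16 variants the commit adds element-wise through float via type_convert, rather than using the commented-out integer-union addition, since bf16 bit patterns cannot be summed as raw uint16_t values.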