Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
79cf90f2
Commit
79cf90f2
authored
Aug 16, 2023
by
letaoqin
Browse files
add code to device
parent
72a345c6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
126 additions
and
56 deletions
+126
-56
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
...ten_bias/batched_multihead_attention_bias_backward_v2.cpp
+47
-48
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+46
-8
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+33
-0
No files found.
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
View file @
79cf90f2
...
...
@@ -70,8 +70,7 @@ using AccDataType = F32;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
// INT32
using
DDataType
=
F16
;
using
Acc0BiasDataType
=
DDataType
;
using
Acc0BiasDataType
=
F16
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
...
...
@@ -414,35 +413,35 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
2
,
2
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
D
DataType
>
{
-
2
,
2
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
Acc0Bias
DataType
>
{
-
2
,
2
});
break
;
case
2
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
0.0
,
1.0
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
0.0
,
1.0
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
-
0.5
,
0.5
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_3
<
InputDataType
>
{
-
0.5
,
0.5
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
D
DataType
>
{
-
0.5
,
0.5
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
Acc0Bias
DataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
InputDataType
>
{
-
5
,
5
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D
DataType
>
{
1
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0Bias
DataType
>
{
1
});
break
;
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D
DataType
>
{
1
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0Bias
DataType
>
{
1
});
break
;
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D
DataType
>
{
1
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0Bias
DataType
>
{
1
});
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
...
...
@@ -450,7 +449,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D
DataType
>
{
1
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0Bias
DataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -464,7 +463,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
D
DataType
>
{
1
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0Bias
DataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -477,7 +476,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
InputDataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
D
DataType
>
d_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
Acc0Bias
DataType
>
d_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
...
...
@@ -498,7 +497,7 @@ int run(int argc, char* argv[])
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
InputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
D
DataType
)
*
d_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
Acc0Bias
DataType
)
*
d_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
z_device_buf
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
InputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
...
...
@@ -529,8 +528,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
D
DataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_biases;
nullptr
,
// p_acc1_biases;
static_cast
<
Acc0Bias
DataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_biases;
nullptr
,
// p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
...
...
@@ -563,41 +562,41 @@ int run(int argc, char* argv[])
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
}
// not need output z matrix
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
D
DataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_biases;
nullptr
,
// p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
d_gs_ms_ns_lengths
,
// acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths,
{},
// acc1_biases_gs_ms_os_strides,
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
QKVElementOp
{},
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
ygrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Acc0Bias
DataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_biases;
nullptr
,
// p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
d_gs_ms_ns_lengths
,
// acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths,
{},
// acc1_biases_gs_ms_os_strides,
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
QKVElementOp
{},
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
qgrad_device_buf
.
SetZero
();
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
79cf90f2
...
...
@@ -26,6 +26,7 @@ namespace device {
template
<
typename
GridwiseGemm
,
typename
InputDataType
,
typename
D0DataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
LSEDataType
,
...
...
@@ -36,6 +37,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
...
...
@@ -54,6 +56,7 @@ __global__ void
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2
(
const
InputDataType
*
__restrict__
p_a_grid
,
const
InputDataType
*
__restrict__
p_b_grid
,
const
D0DataType
*
__restrict__
p_d0_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_b1_grid
,
const
InputDataType
*
__restrict__
p_c_grid
,
...
...
@@ -69,6 +72,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
...
...
@@ -100,6 +104,8 @@ __global__ void
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g_idx
)));
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetD0BasePtr
(
g_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
...
...
@@ -114,6 +120,9 @@ __global__ void
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
ignore
=
p_d0_grid
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
d0_batch_offset
;
if
constexpr
(
Deterministic
)
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
...
...
@@ -188,6 +197,7 @@ __global__ void
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_d0_grid
;
ignore
=
p_z_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
...
...
@@ -203,6 +213,7 @@ __global__ void
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -598,14 +609,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
struct
ComputeBasePtrOfStridedBatch
{
// Default constructor: empty body and no member initializers.
// NOTE(review): presumably exists so an instance can be default-constructed
// first and then assigned a fully built value later -- confirm against the
// owning argument's constructor.
ComputeBasePtrOfStridedBatch
()
{}
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
D0GridDesc_G_M_N
&
d0_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
d0_grid_desc_g_m_n_
(
d0_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
...
...
@@ -623,6 +637,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
b_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
// Returns the element offset of batch `g_idx` in the tensor described by
// d0_grid_desc_g_m_n_, i.e. the descriptor's offset at index (g_idx, 0, 0).
// Despite the "Ptr" in the name, the result is an offset (long_index_t),
// not a pointer.
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
    return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
...
...
@@ -646,6 +664,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
...
...
@@ -656,6 +675,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<
InputDataType
,
// TODO: distinguish A/B datatype
D0DataType
,
OutputDataType
,
ZDataType
,
GemmDataType
,
...
...
@@ -671,6 +691,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
KGridDesc_N_K
,
D0GridDesc_M_N
,
ZGridDesc_M_N
,
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
...
...
@@ -819,13 +840,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
compute_base_ptr_of_batch_
{
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())},
p_drop_
{
p_drop
}
{
// TODO: implement bias addition
...
...
@@ -846,6 +860,26 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
y_grid_desc_m_o_
);
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
auto
d0_grid_desc_m_n
=
MakeDGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
);
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
);
}
compute_base_ptr_of_batch_
=
ComputeBasePtrOfStridedBatch
(
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
d0_grid_desc_g_m_n_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
...
@@ -898,7 +932,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
D0GridDesc_M_N
d0_grid_desc_m_n_
;
typename
GridwiseGemm
::
D0GridDesc
riptor
_M
0
_N
0_M1_M2_N1_M3
d0_grid_desc_m
0
_n
0_m1_m2_n1_m3
_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
...
...
@@ -978,6 +1012,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2
<
GridwiseGemm
,
InputDataType
,
D0DataType
,
OutputDataType
,
ZDataType
,
LSEDataType
,
...
...
@@ -988,6 +1023,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
...
...
@@ -1008,6 +1044,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_d0_grid_
,
arg
.
p_z_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -1023,6 +1060,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
79cf90f2
...
...
@@ -21,6 +21,7 @@
namespace
ck
{
template
<
typename
InputDataType
,
typename
D0DataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
GemmDataType
,
...
...
@@ -36,6 +37,7 @@ template <typename InputDataType,
typename
QGridDesc_K0_M_K1
,
typename
KGridDesc_K0_N_K1
,
typename
KGridDesc_N_K
,
typename
D0GridDesc_M_N
,
typename
ZGridDesc_M_N
,
typename
VGridDesc_N0_O_N1
,
typename
YGridDesc_M_O
,
...
...
@@ -120,6 +122,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
// D0 tiling factors along M. They are the inner three lengths of the M unmerge
// (MBlock, D0M1, D0M2, D0M3) used by MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3.
// When MPerXdl is even and MPerBlock is a multiple of MPerXdl, the product
// D0M1 * D0M2 * D0M3 equals MPerBlock.
static constexpr auto D0M3 = Number<2>{};                     // innermost M factor (fixed at 2)
static constexpr auto D0M2 = Number<MPerXdl / D0M3.value>{};  // MPerXdl split by D0M3
static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};   // MPerXdl-sized chunks per block
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
...
...
@@ -1153,6 +1160,32 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_block_bytes_end
);
}
// Builds the 6-D (M0, N0, M1, M2, N1, M3) view of a 2-D (M, N) D0 descriptor.
//
//   input dim 0 (M) is unmerged into (M / MPerBlock, D0M1, D0M2, D0M3)
//       -> output dims 0, 2, 3, 5
//   input dim 1 (N) is unmerged into (N / NPerBlock, NPerBlock)
//       -> output dims 1, 4
//
// NOTE(review): the block counts come from plain divisions, so M and N are
// presumably expected to be multiples of MPerBlock / NPerBlock -- confirm
// against the callers.
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
    const auto m_length = d0_grid_desc_m_n.GetLength(I0);
    const auto n_length = d0_grid_desc_m_n.GetLength(I1);

    const auto num_m_blocks = m_length / MPerBlock;
    const auto num_n_blocks = n_length / NPerBlock;

    return transform_tensor_descriptor(
        d0_grid_desc_m_n,
        make_tuple(make_unmerge_transform(make_tuple(num_m_blocks, D0M1, D0M2, D0M3)),
                   make_unmerge_transform(make_tuple(num_n_blocks, Number<NPerBlock>{}))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));
}
// Empty type: declares no members.
// NOTE(review): looks like a placeholder for D0-specific helpers -- confirm intent.
struct D0 {};
// Type returned by MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3 when given a
// default-constructed D0GridDesc_M_N, with cv/reference qualifiers removed.
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
    remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment