Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
21cec2bb
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "3eecbfb6ec231cd8012faceb8b6fbc87199db60d"
Commit
21cec2bb
authored
Aug 29, 2023
by
letaoqin
Browse files
change biases to bias in batched mha
parent
ff6d9e1f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
77 additions
and
77 deletions
+77
-77
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+40
-40
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+37
-37
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
21cec2bb
...
@@ -750,8 +750,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -750,8 +750,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
es
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -763,12 +763,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -763,12 +763,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
...
@@ -778,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -778,7 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_d0_grid_
{
p_acc0_bias
es
},
p_d0_grid_
{
p_acc0_bias
},
p_z_grid_
{
p_z_grid
},
p_z_grid_
{
p_z_grid
},
p_b1_grid_
{
p_b1_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
...
@@ -836,12 +836,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -836,12 +836,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_drop_
{
p_drop
}
p_drop_
{
p_drop
}
{
{
// TODO: implement bias addition
// TODO: implement bias addition
ignore
=
p_acc0_bias
es
;
ignore
=
p_acc0_bias
;
ignore
=
p_acc1_bias
es
;
ignore
=
p_acc1_bias
;
ignore
=
acc0_bias
es
_gs_ms_ns_lengths
;
ignore
=
acc0_bias_gs_ms_ns_lengths
;
ignore
=
acc0_bias
es
_gs_ms_ns_strides
;
ignore
=
acc0_bias_gs_ms_ns_strides
;
ignore
=
acc1_bias
es
_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_bias_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_bias
es
_gs_ms_gemm1ns_strides
;
ignore
=
acc1_bias_gs_ms_gemm1ns_strides
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
...
@@ -854,13 +854,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -854,13 +854,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
}
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
{
const
auto
d0_grid_desc_m_n
=
MakeDGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
,
const
auto
d0_grid_desc_m_n
=
acc0_biase
s_gs_ms_ns_strides
);
MakeDGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bia
s_gs_ms_ns_strides
);
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
}
}
compute_base_ptr_of_batch_
=
ComputeBasePtrOfStridedBatch
(
compute_base_ptr_of_batch_
=
ComputeBasePtrOfStridedBatch
(
...
@@ -1176,8 +1176,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1176,8 +1176,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
es
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -1189,12 +1189,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1189,12 +1189,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
...
@@ -1213,8 +1213,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1213,8 +1213,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_qgrad_grid
,
p_qgrad_grid
,
p_kgrad_grid
,
p_kgrad_grid
,
p_vgrad_grid
,
p_vgrad_grid
,
p_acc0_bias
es
,
p_acc0_bias
,
p_acc1_bias
es
,
p_acc1_bias
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_lengths
,
...
@@ -1226,10 +1226,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1226,10 +1226,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
acc_element_op
,
...
@@ -1254,8 +1254,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1254,8 +1254,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void
*
p_qgrad_grid
,
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
void
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
es
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -1267,12 +1267,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1267,12 +1267,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
...
@@ -1292,8 +1292,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1292,8 +1292,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast
<
OutputDataType
*>
(
p_qgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_qgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_kgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_kgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
es
),
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
),
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
es
),
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
),
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_lengths
,
...
@@ -1305,10 +1305,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1305,10 +1305,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
acc1_bias_gs_ms_gemm1ns_lengths
,
acc1_bias
es
_gs_ms_gemm1ns_strides
,
acc1_bias_gs_ms_gemm1ns_strides
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
acc_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
21cec2bb
...
@@ -766,8 +766,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -766,8 +766,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
es
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -779,12 +779,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -779,12 +779,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
...
@@ -794,7 +794,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -794,7 +794,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_d0_grid_
{
p_acc0_bias
es
},
p_d0_grid_
{
p_acc0_bias
},
p_z_grid_
{
p_z_grid
},
p_z_grid_
{
p_z_grid
},
p_b1_grid_
{
p_b1_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
...
@@ -851,9 +851,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -851,9 +851,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_drop_
{
p_drop
}
p_drop_
{
p_drop
}
{
{
// TODO: implement bias addition
// TODO: implement bias addition
ignore
=
p_acc1_bias
es
;
ignore
=
p_acc1_bias
;
ignore
=
acc1_bias
es
_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_bias_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_bias
es
_gs_ms_gemm1ns_strides
;
ignore
=
acc1_bias_gs_ms_gemm1ns_strides
;
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
...
@@ -867,13 +867,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -867,13 +867,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
{
const
auto
d0_grid_desc_m_n
=
MakeDGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
,
const
auto
d0_grid_desc_m_n
=
acc0_biase
s_gs_ms_ns_strides
);
MakeDGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bia
s_gs_ms_ns_strides
);
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
d0_grid_desc_m0_n0_m1_m2_n1_m3_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
}
}
compute_base_ptr_of_batch_
=
ComputeBasePtrOfStridedBatch
(
compute_base_ptr_of_batch_
=
ComputeBasePtrOfStridedBatch
(
...
@@ -1209,8 +1209,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1209,8 +1209,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
es
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
es
,
const
D1DataType
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -1222,12 +1222,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1222,12 +1222,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
...
@@ -1246,8 +1246,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1246,8 +1246,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_qgrad_grid
,
p_qgrad_grid
,
p_kgrad_grid
,
p_kgrad_grid
,
p_vgrad_grid
,
p_vgrad_grid
,
p_acc0_bias
es
,
p_acc0_bias
,
p_acc1_bias
es
,
p_acc1_bias
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_lengths
,
...
@@ -1259,10 +1259,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1259,10 +1259,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
acc_element_op
,
...
@@ -1287,8 +1287,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1287,8 +1287,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void
*
p_qgrad_grid
,
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
void
*
p_vgrad_grid
,
const
void
*
p_acc0_bias
es
,
const
void
*
p_acc0_bias
,
const
void
*
p_acc1_bias
es
,
const
void
*
p_acc1_bias
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -1300,12 +1300,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1300,12 +1300,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
// acc1_bias
es
_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
acc1_bias
es
_gs_ms_gemm1ns_strides
,
// acc1_bias
es
_gs_ms_os_strides
acc1_bias_gs_ms_gemm1ns_strides
,
// acc1_bias_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
AccElementwiseOperation
acc_element_op
,
...
@@ -1325,8 +1325,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1325,8 +1325,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast
<
OutputDataType
*>
(
p_qgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_qgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_kgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_kgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
es
),
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
),
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
es
),
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
),
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
b_gs_ns_ks_lengths
,
...
@@ -1338,10 +1338,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1338,10 +1338,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
acc0_bias
es
_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias
es
_gs_ms_gemm1ns_lengths
,
acc1_bias_gs_ms_gemm1ns_lengths
,
acc1_bias
es
_gs_ms_gemm1ns_strides
,
acc1_bias_gs_ms_gemm1ns_strides
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
acc_element_op
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment