Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d10f25a0
Commit
d10f25a0
authored
Aug 29, 2023
by
letaoqin
Browse files
bwd biaes to bias
parent
127982f1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
84 additions
and
86 deletions
+84
-86
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
...ten_bias/batched_multihead_attention_bias_backward_v2.cpp
+12
-12
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
...ten_bias/grouped_multihead_attention_bias_backward_v2.cpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+35
-36
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+35
-36
No files found.
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
View file @
d10f25a0
...
@@ -515,8 +515,8 @@ int run(int argc, char* argv[])
...
@@ -515,8 +515,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Acc0BiasDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_bias
es
;
static_cast
<
Acc0BiasDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_bias;
nullptr
,
// p_acc1_bias
es
;
nullptr
,
// p_acc1_bias;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_lengths
,
...
@@ -528,10 +528,10 @@ int run(int argc, char* argv[])
...
@@ -528,10 +528,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
d_gs_ms_ns_lengths
,
// acc0_bias
es
_gs_ms_ns_lengths
d_gs_ms_ns_lengths
,
// acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_bias
es
_gs_ms_ns_strides
d_gs_ms_ns_strides
,
// acc0_bias_gs_ms_ns_strides
{},
// acc1_bias
es
_gs_ms_os_lengths,
{},
// acc1_bias_gs_ms_os_lengths,
{},
// acc1_bias
es
_gs_ms_os_strides,
{},
// acc1_bias_gs_ms_os_strides,
QKVElementOp
{},
QKVElementOp
{},
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
Scale
{
alpha
},
...
@@ -560,8 +560,8 @@ int run(int argc, char* argv[])
...
@@ -560,8 +560,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
Acc0BiasDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_bias
es
;
static_cast
<
Acc0BiasDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_bias;
nullptr
,
// p_acc1_bias
es
;
nullptr
,
// p_acc1_bias;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
k_gs_ns_ks_lengths
,
...
@@ -573,10 +573,10 @@ int run(int argc, char* argv[])
...
@@ -573,10 +573,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
d_gs_ms_ns_lengths
,
// acc0_bias
es
_gs_ms_ns_lengths
d_gs_ms_ns_lengths
,
// acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_bias
es
_gs_ms_ns_strides
d_gs_ms_ns_strides
,
// acc0_bias_gs_ms_ns_strides
{},
// acc1_bias
es
_gs_ms_os_lengths,
{},
// acc1_bias_gs_ms_os_lengths,
{},
// acc1_bias
es
_gs_ms_os_strides,
{},
// acc1_bias_gs_ms_os_strides,
QKVElementOp
{},
QKVElementOp
{},
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
Scale
{
alpha
},
...
...
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
View file @
d10f25a0
...
@@ -437,8 +437,8 @@ int run(int argc, char* argv[])
...
@@ -437,8 +437,8 @@ int run(int argc, char* argv[])
lse_gs_ms_strides
,
lse_gs_ms_strides
,
d0_gs_ms_ns_lengths
,
d0_gs_ms_ns_lengths
,
d0_gs_ms_ns_strides
,
d0_gs_ms_ns_strides
,
{},
//
std::array<std::vector<ck::index_t>, 1>{
acc1_bias
es
_gs_ms_os_lengths
}
,
{},
// acc1_bias_gs_ms_os_lengths,
{},
//
std::array<std::vector<ck::index_t>, 1>{
acc1_bias
es
_gs_ms_os_strides
}
,
{},
// acc1_bias_gs_ms_os_strides,
});
});
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
d10f25a0
...
@@ -299,11 +299,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -299,11 +299,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
acc0_bias
es
_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias
es
_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc1_bias
es
_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
acc1_bias_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
acc1_bias
es
_gs_ms_os_strides
;
std
::
vector
<
index_t
>
acc1_bias_gs_ms_os_strides
;
};
};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
}
}
}
// D in Gemm0 C position
// D in Gemm0 C position
static
auto
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias
es
_gs_ms_ns_lengths
,
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
static
auto
static
auto
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_lengths
,
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias
es
_gs_ms_ns_lengths
,
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
...
@@ -756,8 +755,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -756,8 +755,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -788,9 +787,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -788,9 +787,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Kgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Kgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Vgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Vgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())
&&
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
es
.
size
())
||
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
.
size
())
||
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
es
.
size
()
==
0
))
&&
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
.
size
()
==
0
))
&&
0
==
p_acc1_bias
es
.
size
()))
0
==
p_acc1_bias
.
size
()))
{
{
throw
std
::
runtime_error
(
"wrong! group_count_ != p_As/b/b1/c.size"
);
throw
std
::
runtime_error
(
"wrong! group_count_ != p_As/b/b1/c.size"
);
}
}
...
@@ -804,8 +803,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -804,8 +803,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
InputDataType
*>
(
p_Bs
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
InputDataType
*>
(
p_Bs
[
i
]);
const
auto
p_d0_grid
=
const
auto
p_d0_grid
=
(
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
es
.
size
())
==
group_count_
)
(
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
.
size
())
==
group_count_
)
?
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
es
[
i
])
?
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
[
i
])
:
nullptr
;
:
nullptr
;
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
InputDataType
*>
(
p_B1s
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
InputDataType
*>
(
p_B1s
[
i
]);
...
@@ -827,8 +826,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -827,8 +826,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
{
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_bias
es
_gs_ms_ns_lengths
;
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_bias_gs_ms_ns_lengths
;
tmp_d0_gs_ms_ns_strides
=
problem_desc
.
acc0_bias
es
_gs_ms_ns_strides
;
tmp_d0_gs_ms_ns_strides
=
problem_desc
.
acc0_bias_gs_ms_ns_strides
;
}
}
else
else
{
{
...
@@ -971,12 +970,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -971,12 +970,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_n_length_stride
});
d0_n_length_stride
});
}
}
// TODO: implement bias addition
// TODO: implement bias addition
// ignore = p_acc0_bias
es
;
// ignore = p_acc0_bias;
// ignore = p_acc1_bias
es
;
// ignore = p_acc1_bias;
// ignore = acc0_bias
es
_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias
es
_gs_ms_ns_strides;
// ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias
es
_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias
es
_gs_ms_gemm1ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_strides;
}
}
// element-wise op
// element-wise op
...
@@ -1197,8 +1196,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1197,8 +1196,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -1218,8 +1217,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1218,8 +1217,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads
,
p_Qgrads
,
p_Kgrads
,
p_Kgrads
,
p_Vgrads
,
p_Vgrads
,
p_acc0_bias
es
,
p_acc0_bias
,
p_acc1_bias
es
,
p_acc1_bias
,
problem_desc_vec
,
problem_desc_vec
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -1245,8 +1244,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1245,8 +1244,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -1266,8 +1265,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1266,8 +1265,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads
,
p_Qgrads
,
p_Kgrads
,
p_Kgrads
,
p_Vgrads
,
p_Vgrads
,
p_acc0_bias
es
,
// cast in struct Argument
p_acc0_bias
,
// cast in struct Argument
p_acc1_bias
es
,
// cast in struct Argument
p_acc1_bias
,
// cast in struct Argument
problem_desc_vec
,
problem_desc_vec
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
d10f25a0
...
@@ -306,11 +306,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -306,11 +306,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
acc0_bias
es
_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias
es
_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc1_bias
es
_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
acc1_bias_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
acc1_bias
es
_gs_ms_os_strides
;
std
::
vector
<
index_t
>
acc1_bias_gs_ms_os_strides
;
};
};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
}
}
}
// D in Gemm0 C position
// D in Gemm0 C position
static
auto
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias
es
_gs_ms_ns_lengths
,
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
static
auto
static
auto
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_lengths
,
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias
es
_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias
es
_gs_ms_ns_lengths
,
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias
es
_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
...
@@ -764,8 +763,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -764,8 +763,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -796,9 +795,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -796,9 +795,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Kgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Kgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Vgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Vgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())
&&
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
es
.
size
())
||
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
.
size
())
||
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
es
.
size
()
==
0
))
&&
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
.
size
()
==
0
))
&&
0
==
p_acc1_bias
es
.
size
()))
0
==
p_acc1_bias
.
size
()))
{
{
throw
std
::
runtime_error
(
"wrong! group_count_ != p_As/b/b1/c.size"
);
throw
std
::
runtime_error
(
"wrong! group_count_ != p_As/b/b1/c.size"
);
}
}
...
@@ -812,8 +811,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -812,8 +811,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
InputDataType
*>
(
p_Bs
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
InputDataType
*>
(
p_Bs
[
i
]);
const
auto
p_d0_grid
=
const
auto
p_d0_grid
=
(
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
es
.
size
())
==
group_count_
)
(
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_bias
.
size
())
==
group_count_
)
?
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
es
[
i
])
?
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
[
i
])
:
nullptr
;
:
nullptr
;
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
InputDataType
*>
(
p_B1s
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
InputDataType
*>
(
p_B1s
[
i
]);
...
@@ -835,8 +834,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -835,8 +834,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
{
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_bias
es
_gs_ms_ns_lengths
;
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_bias_gs_ms_ns_lengths
;
tmp_d0_gs_ms_ns_strides
=
problem_desc
.
acc0_bias
es
_gs_ms_ns_strides
;
tmp_d0_gs_ms_ns_strides
=
problem_desc
.
acc0_bias_gs_ms_ns_strides
;
}
}
else
else
{
{
...
@@ -979,12 +978,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -979,12 +978,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_n_length_stride
});
d0_n_length_stride
});
}
}
// TODO: implement bias addition
// TODO: implement bias addition
// ignore = p_acc0_bias
es
;
// ignore = p_acc0_bias;
// ignore = p_acc1_bias
es
;
// ignore = p_acc1_bias;
// ignore = acc0_bias
es
_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias
es
_gs_ms_ns_strides;
// ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias
es
_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias
es
_gs_ms_gemm1ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_strides;
}
}
// element-wise op
// element-wise op
...
@@ -1209,8 +1208,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1209,8 +1208,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -1230,8 +1229,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1230,8 +1229,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads
,
p_Qgrads
,
p_Kgrads
,
p_Kgrads
,
p_Vgrads
,
p_Vgrads
,
p_acc0_bias
es
,
p_acc0_bias
,
p_acc1_bias
es
,
p_acc1_bias
,
problem_desc_vec
,
problem_desc_vec
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -1257,8 +1256,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1257,8 +1256,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Qgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Kgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
std
::
vector
<
void
*>&
p_Vgrads
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc0_bias
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
es
,
const
std
::
vector
<
const
void
*>&
p_acc1_bias
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
const
std
::
vector
<
ProblemDesc
>&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -1278,8 +1277,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1278,8 +1277,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads
,
p_Qgrads
,
p_Kgrads
,
p_Kgrads
,
p_Vgrads
,
p_Vgrads
,
p_acc0_bias
es
,
// cast in struct Argument
p_acc0_bias
,
// cast in struct Argument
p_acc1_bias
es
,
// cast in struct Argument
p_acc1_bias
,
// cast in struct Argument
problem_desc_vec
,
problem_desc_vec
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment