Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e0d6326b
Commit
e0d6326b
authored
Aug 14, 2023
by
letaoqin
Browse files
change interface
parent
b60595f9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
277 additions
and
280 deletions
+277
-280
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
...ten_bias/batched_multihead_attention_bias_backward_v2.cpp
+43
-46
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+117
-117
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+117
-117
No files found.
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
View file @
e0d6326b
...
...
@@ -71,8 +71,8 @@ using ShuffleDataType = F32;
using
LSEDataType
=
F32
;
using
ZDataType
=
U16
;
// INT32
using
DDataType
=
F16
;
using
Acc0BiasDataType
=
ck
::
Tuple
<
DDataType
>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
DDataType
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
...
...
@@ -529,9 +529,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
static_cast
<
DDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
// p_acc0_biases;
nullptr
,
// p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
...
...
@@ -543,12 +542,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
d_gs_ms_ns_lengths
,
// acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths,
{},
// acc1_biases_gs_ms_os_strides,
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
...
...
@@ -566,8 +563,8 @@ int run(int argc, char* argv[])
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
}
// not need output z matrix
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
...
...
@@ -577,8 +574,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()
}
,
//
std::array<void*, 1>
p_acc0_biases;
{},
//
std::array<void*, 1>
p_acc1_biases;
static_cast
<
DDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()
)
,
// p_acc0_biases;
nullptr
,
// p_acc1_biases;
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
k_gs_ns_ks_lengths
,
...
...
@@ -590,10 +587,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_lengths
}
,
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_strides
}
,
// acc0_biases_gs_ms_ns_strides
{},
// std::array<std::vector<ck::index_t>, 1>{
acc1_biases_gs_ms_os_lengths
}
,
{},
// std::array<std::vector<ck::index_t>, 1>{
acc1_biases_gs_ms_os_strides
}
,
d_gs_ms_ns_lengths
,
// acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_biases_gs_ms_ns_strides
{},
//
acc1_biases_gs_ms_os_lengths,
{},
//
acc1_biases_gs_ms_os_strides,
QKVElementOp
{},
QKVElementOp
{},
Scale
{
alpha
},
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
e0d6326b
...
...
@@ -291,11 +291,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
()
;
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
()
;
using
D0DataType
=
Acc0BiasDataType
;
using
D1DataType
=
Acc1BiasDataType
;
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"
Bias addition is unimplemented"
);
static_assert
(
std
::
is_void
<
D1DataType
>::
value
,
"Acc1
Bias addition is unimplemented"
);
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
;
...
...
@@ -702,8 +702,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InputDataType
*
p_a_grid
,
Argument
(
const
InputDataType
*
p_a_grid
,
const
InputDataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
const
InputDataType
*
p_b1_grid
,
...
...
@@ -713,8 +712,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
D0DataType
*
p_acc0_biases
,
const
D1DataType
*
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -726,11 +725,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1108,8 +1107,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InputDataType
*
p_a
,
static
auto
MakeArgument
(
const
InputDataType
*
p_a
,
const
InputDataType
*
p_b
,
ZDataType
*
p_z
,
const
InputDataType
*
p_b1
,
...
...
@@ -1119,8 +1118,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
D0DataType
*
p_acc0_biases
,
const
D1DataType
*
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1132,11 +1131,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1197,8 +1196,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
D0DataType
*
p_acc0_biases
,
const
D1DataType
*
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1210,11 +1209,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1224,7 +1223,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_a
),
static_cast
<
const
InputDataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
InputDataType
*>
(
p_b1
),
...
...
@@ -1234,8 +1234,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast
<
OutputDataType
*>
(
p_qgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_kgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_acc0_biases
)
,
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_biases
)
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
e0d6326b
...
...
@@ -299,11 +299,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
()
;
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
()
;
using
D0DataType
=
Acc0BiasDataType
;
using
D1DataType
=
Acc1BiasDataType
;
// TODO: implement bias combination
static_assert
(
NumAcc1Bias
==
0
,
"
Bias addition is unimplemented"
);
static_assert
(
std
::
is_void
<
D1DataType
>::
value
,
"Acc1
Bias addition is unimplemented"
);
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
;
...
...
@@ -718,8 +718,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InputDataType
*
p_a_grid
,
Argument
(
const
InputDataType
*
p_a_grid
,
const
InputDataType
*
p_b_grid
,
ZDataType
*
p_z_grid
,
const
InputDataType
*
p_b1_grid
,
...
...
@@ -729,8 +728,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
D0DataType
*
p_acc0_biases
,
const
D1DataType
*
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -742,11 +741,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1143,8 +1142,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InputDataType
*
p_a
,
static
auto
MakeArgument
(
const
InputDataType
*
p_a
,
const
InputDataType
*
p_b
,
ZDataType
*
p_z
,
const
InputDataType
*
p_b1
,
...
...
@@ -1154,8 +1153,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType
*
p_qgrad_grid
,
OutputDataType
*
p_kgrad_grid
,
OutputDataType
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
D0DataType
*
p_acc0_biases
,
const
D1DataType
*
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1167,11 +1166,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1232,8 +1231,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
void
*
p_acc0_biases
,
const
void
*
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1245,11 +1244,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>
&
acc0_biases_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
const
std
::
vector
<
ck
::
index_t
>
&
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
...
...
@@ -1259,7 +1258,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InputDataType
*>
(
p_a
),
static_cast
<
const
InputDataType
*>
(
p_b
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
const
InputDataType
*>
(
p_b1
),
...
...
@@ -1269,8 +1269,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast
<
OutputDataType
*>
(
p_qgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_kgrad_grid
),
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
p_acc0_biases
,
// cast in struct Argument
p_acc1_biases
,
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_acc0_biases
)
,
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_biases
)
,
// cast in struct Argument
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment