Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d0cd6886
"src/include/functional3.hip.hpp" did not exist on "3a6044aa84e0836ddc083233b7d616c19adeb677"
Commit
d0cd6886
authored
Sep 07, 2023
by
guangzlu
Browse files
seperate inheriting for bwd qloop v1 and v2
parent
4b517872
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
96 additions
and
35 deletions
+96
-35
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
+62
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+17
-17
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+17
-17
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
View file @
d0cd6886
...
@@ -139,7 +139,68 @@ template <index_t NumDimG,
...
@@ -139,7 +139,68 @@ template <index_t NumDimG,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
MaskingSpecialization
MaskingSpec
>
struct
DeviceBatchedMultiheadAttentionBackward
:
public
BaseOperator
struct
DeviceBatchedMultiheadAttentionBackwardQloopV1
:
public
BaseOperator
{
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_z
,
const
void
*
p_b1
,
const
void
*
p_c
,
const
void
*
p_lse
,
const
void
*
p_ygrad_grid
,
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
// lse_gs_ms_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
float
p_dropout
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
InputDataType
,
typename
OutputDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceBatchedMultiheadAttentionBackwardQloopV2
:
public
BaseOperator
{
{
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
d0cd6886
...
@@ -288,23 +288,23 @@ template <index_t NumDimG,
...
@@ -288,23 +288,23 @@ template <index_t NumDimG,
bool
Deterministic
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
struct
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
:
public
DeviceBatchedMultiheadAttentionBackward
<
NumDimG
,
:
public
DeviceBatchedMultiheadAttentionBackward
QloopV1
<
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
InputDataType
,
InputDataType
,
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>
MaskingSpec
>
{
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
"Number of dimension must be greater than 0"
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
d0cd6886
...
@@ -294,23 +294,23 @@ template <index_t NumDimG,
...
@@ -294,23 +294,23 @@ template <index_t NumDimG,
bool
Deterministic
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
struct
DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
:
public
DeviceBatchedMultiheadAttentionBackward
<
NumDimG
,
:
public
DeviceBatchedMultiheadAttentionBackward
QloopV2
<
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
InputDataType
,
InputDataType
,
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>
MaskingSpec
>
{
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
"Number of dimension must be greater than 0"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment