gaoqiong / composable_kernel · Commits · a39dd61f

Commit a39dd61f authored Mar 09, 2023 by danyao12
Parent: 79f3caf8

    refactor grouped bwd example and fix some bugs

Showing 4 changed files with 186 additions and 192 deletions (+186 -192)
Changed files:
    example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp  (+5 -6)
    example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp  (+163 -169)
    include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp  (+11 -10)
    include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp  (+7 -7)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp

@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define RANGE_HDKO 2 // 0~2
+#define RANGE_HDKO 1 // 0~2
 #include <iostream>
 #include <numeric>
@@ -523,7 +523,7 @@ int run(int argc, char* argv[])
         G0            = std::stoi(argv[8]);
         G1            = std::stoi(argv[9]);
-        alpha         = std::stof(argv[10]);
+        alpha         = std::stof(argv[10]);
         p_drop        = std::stof(argv[11]);
         input_permute = std::stoi(argv[12]);
@@ -540,9 +540,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
-    float p_dropout              = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
-    float rp_dropout             = 1.0 / p_dropout;
+    float p_dropout              = 1 - p_drop;
+    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    float rp_dropout             = 1.0 / p_dropout;
     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
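Aside (not part of the commit): a minimal standalone sketch of the dropout bookkeeping in the hunk above, showing how p_drop becomes the keep probability p_dropout, the 16-bit RNG comparison threshold, and the 1/p_dropout rescale factor; the value p_drop = 0.2 and the printout are purely illustrative.

    #include <cmath>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        float p_drop = 0.2f; // probability of zeroing an element (illustrative value)

        // Same arithmetic as in the example above.
        float p_dropout              = 1 - p_drop;                                // keep probability, 0.8
        uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); // threshold in [0, 65535], here 52428
        float rp_dropout             = 1.0 / p_dropout;                           // rescale factor, 1.25

        std::cout << p_dropout << ' ' << p_dropout_in_16bits << ' ' << rp_dropout << std::endl;
        return 0;
    }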
@@ -678,7 +678,6 @@ int run(int argc, char* argv[])
         // = 0
     }
     // calculate y & log-sum-exp beforehand
     Tensor<DataType> q_g_m_k({BatchCount, M, K});
     Tensor<DataType> k_g_n_k({BatchCount, N, K});
     Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp

(This diff is collapsed and not shown here.)
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_pt1.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
@@ -98,7 +98,7 @@ __global__ void
         unsigned short* z_matrix_ptr =
             (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                     : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
         GridwiseGemm::template Run<HasMainKBlockLoop>(arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                                                       arg_ptr[group_id].p_b_grid_ + b_batch_offset,
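Aside (not part of the commit): the z (dropout mask) grid is optional, so the kernel only adds the per-batch offset when the group actually supplies a pointer; offsetting a null pointer would be undefined behaviour. A minimal sketch of that guard pattern, with a hypothetical GroupArg struct standing in for the real per-group kernel argument type:

    #include <cstddef>

    // Hypothetical stand-in for one group's kernel arguments.
    struct GroupArg
    {
        unsigned short* p_z_grid_; // may be nullptr when no dropout mask output is requested
    };

    // Returns the per-batch z pointer, or nullptr when the group provides no z buffer.
    // Mirrors the ternary expression in the hunk above.
    unsigned short* batch_z_ptr(const GroupArg& arg, std::ptrdiff_t z_batch_offset)
    {
        return arg.p_z_grid_ == nullptr ? nullptr : arg.p_z_grid_ + z_batch_offset;
    }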
@@ -211,7 +211,7 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
+struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -223,7 +223,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -448,7 +448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         }
     }
     using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
     using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
     using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
@@ -534,7 +533,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     };
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
         DataType, // TODO: distinguish A/B datatype
         GemmDataType,
         GemmAccDataType,
@@ -711,7 +710,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         }
         grid_size_ = 0;
-        for(std::size_t i = 0; i < group_count_; i++)
+        for(index_t i = 0; i < group_count_; i++)
         {
             const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
             const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
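Aside (not part of the commit): a plausible reading of the std::size_t → index_t change in this and the later loops is type consistency, assuming group_count_ is a signed index_t as elsewhere in CK; comparing an unsigned std::size_t counter against a signed count mixes signedness and draws -Wsign-compare warnings. A minimal sketch of the two variants (the Groups struct and its member are illustrative only):

    #include <cstddef>

    using index_t = int; // CK's index_t is a signed integer type; plain int is used here for illustration

    struct Groups
    {
        index_t group_count_ = 4;
    };

    void iterate(const Groups& g)
    {
        // Unsigned counter compared against a signed count: compiles, but -Wsign-compare
        // warns about the implicit conversion in the comparison.
        for(std::size_t i = 0; i < g.group_count_; i++) {}

        // Counter type matches group_count_, so no mixed-sign comparison occurs.
        for(index_t i = 0; i < g.group_count_; i++) {}
    }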
@@ -895,7 +894,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         // for(std::size_t i = 0; i < arg.group_count_; i++)
         // {
         //     const auto K =
-        //         arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
+        //         arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
+        //         arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
         //     const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
         //     all_has_main_k_block_loop &= y;
         //     some_has_main_k_block_loop |= y;
@@ -976,7 +976,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             return false;
         }
-        for(std::size_t i = 0; i < arg.group_count_; i++)
+        for(index_t i = 0; i < arg.group_count_; i++)
         {
             // TODO: Check if tensor specialization & strides mismatch
             const auto& kernel_arg = arg.group_kernel_args_[i];
@@ -986,7 +986,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
             const index_t c_m       = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
             const index_t c_gemm1n  = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
             const index_t a_m       = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-            const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
+            const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
+                                      kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
             if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
             {
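Aside (not part of the commit): the condition above rejects a group whose output (Y) extents disagree with the A and B1 descriptors: M read from the Y descriptor must equal M read from A, and Y's gemm1 N extent must equal the BK0 * BK1 product of B1. A small sketch of that validation with plain integers standing in for the GetLength() calls:

    // Returns true when the output (Y) extents agree with the A and B1 descriptors.
    // All arguments are illustrative stand-ins for descriptor GetLength() results.
    bool extents_consistent(int c_m, int c_gemm1n, int a_m, int b1_k0, int b1_k1)
    {
        const int b1_gemm1n = b1_k0 * b1_k1;
        return c_m == a_m && c_gemm1n == b1_gemm1n;
    }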
@@ -1160,7 +1161,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1"
+        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
             << "<" << BlockSize << ", " << MPerBlock << ", "
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
@@ -16,7 +16,7 @@
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -98,7 +98,7 @@ __global__ void
         unsigned short* z_matrix_ptr =
             (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                     : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
         GridwiseGemm::template Run<HasMainKBlockLoop>(arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                                                       arg_ptr[group_id].p_b_grid_ + b_batch_offset,
@@ -703,7 +703,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         }
         grid_size_ = 0;
-        for(std::size_t i = 0; i < group_count_; i++)
+        for(index_t i = 0; i < group_count_; i++)
         {
             const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
             const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
@@ -884,10 +884,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         bool all_has_main_k_block_loop  = true;
         bool some_has_main_k_block_loop = false;
-        for(std::size_t i = 0; i < arg.group_count_; i++)
+        for(index_t i = 0; i < arg.group_count_; i++)
         {
-            const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
-                           arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
+            const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
+                           arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
             const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
             all_has_main_k_block_loop &= y;
             some_has_main_k_block_loop |= y;
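Aside (not part of the commit): in the hunk above the A grid descriptor is viewed as (AK0, M, AK1), so the full reduction length is recovered as GetLength(I0) * GetLength(I2), i.e. K = AK0 * AK1, and that K is what CalculateHasMainKBlockLoop inspects. A tiny numeric sketch of the same decomposition (the concrete AK0/AK1 values are made up):

    #include <cassert>

    int main()
    {
        // Hypothetical descriptor lengths for an A tensor viewed as (AK0, M, AK1).
        const int ak0 = 16; // GetLength(I0)
        const int ak1 = 8;  // GetLength(I2)

        // Reduction length along K is split as K = AK0 * AK1.
        const int k = ak0 * ak1;
        assert(k == 128);
        return 0;
    }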
@@ -968,7 +968,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
             return false;
         }
-        for(std::size_t i = 0; i < arg.group_count_; i++)
+        for(index_t i = 0; i < arg.group_count_; i++)
        {
             // TODO: Check if tensor specialization & strides mismatch
             const auto& kernel_arg = arg.group_kernel_args_[i];