Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cd6e9903
Commit
cd6e9903
authored
Jun 06, 2023
by
guangzlu
Browse files
added dropout for bwd v5
parent
dc8e0148
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
226 additions
and
104 deletions
+226
-104
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v5.cpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
+23
-13
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
+199
-87
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
View file @
cd6e9903
...
@@ -32,7 +32,7 @@ Kernel outputs:
...
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define DIM
64
// DIM should be a multiple of 8.
#define DIM
128
// DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -715,13 +715,13 @@ int run(int argc, char* argv[])
...
@@ -715,13 +715,13 @@ int run(int argc, char* argv[])
ck
::
index_t
M
=
500
;
// 512
ck
::
index_t
M
=
500
;
// 512
ck
::
index_t
K
=
DIM
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
1
;
// 54
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G1
=
1
;
// 16
ck
::
index_t
G1
=
6
;
// 16
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.
1
;
float
p_drop
=
0.
2
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
const
unsigned
long
long
offset
=
0
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
View file @
cd6e9903
...
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
...
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
typename
LSEGridDescriptor_M
,
...
@@ -70,8 +70,8 @@ __global__ void
...
@@ -70,8 +70,8 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -84,7 +84,9 @@ __global__ void
...
@@ -84,7 +84,9 @@ __global__ void
const
C0MatrixMask
c0_matrix_mask
,
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
const
float
p_drop
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -134,7 +136,7 @@ __global__ void
...
@@ -134,7 +136,7 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -143,6 +145,9 @@ __global__ void
...
@@ -143,6 +145,9 @@ __global__ void
c0_matrix_mask
,
c0_matrix_mask
,
p_drop
,
p_drop
,
ph
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
i
);
}
}
}
}
...
@@ -166,7 +171,7 @@ __global__ void
...
@@ -166,7 +171,7 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -175,6 +180,9 @@ __global__ void
...
@@ -175,6 +180,9 @@ __global__ void
c0_matrix_mask
,
c0_matrix_mask
,
p_drop
,
p_drop
,
ph
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
0
);
}
}
#else
#else
...
@@ -831,8 +839,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -831,8 +839,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
z_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
z_grid_desc_m_n_
);
// Print();
// Print();
}
}
...
@@ -892,8 +900,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -892,8 +900,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
;
// block-to-c-tile map
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
@@ -960,7 +968,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -960,7 +968,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
LSEGridDesc_M
,
...
@@ -994,7 +1002,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -994,7 +1002,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
lse_grid_desc_m_
,
...
@@ -1006,7 +1014,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1006,7 +1014,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg
.
c0_matrix_mask_
,
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
View file @
cd6e9903
...
@@ -127,24 +127,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -127,24 +127,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// C desc for source in blockwise copy
// C desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
{
{
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
M
PerXdl
)),
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
M
3
,
M4
,
M5
)),
make_unmerge_transform
(
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N
3
,
N4
,
N5
))),
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N
PerXdl
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
7
,
8
>
{},
Sequence
<
1
,
3
,
5
,
9
>
{}));
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
index_t
M
,
const
index_t
N
)
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
index_t
M
,
const
index_t
N
)
{
{
...
@@ -467,8 +469,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -467,8 +469,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using
DefaultBlock2CTileMap
=
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
KGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
KGridDesc_N_K
{}))
>
;
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
=
remove_cvref_t
<
decltype
(
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
ZGridDesc_M_N
{}))
>
;
// S / dP Gemm (type 1 rcc)
// S / dP Gemm (type 1 rcc)
struct
Gemm0
struct
Gemm0
...
@@ -1183,8 +1185,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1183,8 +1185,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
&
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
...
@@ -1194,6 +1196,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1194,6 +1196,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
float
p_drop
,
const
float
p_drop
,
ck
::
philox
&
ph
,
ck
::
philox
&
ph
,
const
index_t
g_idx
,
const
index_t
MRaw
,
const
index_t
NRaw
,
const
index_t
block_idx_n
)
const
index_t
block_idx_n
)
{
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
...
@@ -1558,47 +1563,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1558,47 +1563,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// z vgpr copy to global
// z vgpr copy to global
//
//
// z matrix threadwise desc
// z matrix threadwise desc
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
=
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// M
PerXdl
m2
,
// M
GroupNum
m3
,
//
NGroup
Num
m3
,
//
MInput
Num
m4
,
//
NInput
Num
m4
,
//
register
Num
n2
));
//
registerNum
n2
));
//
NPerXdl
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
.
GetElementSpaceSize
(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
.
GetElementSpaceSize
(),
true
>
true
>
z_tenor_buffer
;
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
z_tenor_buffer
.
Clear
();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
// z matrix global desc
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
// ignore = p_z_tmp_grid;
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
// tmp buffer for shuffle
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
.
GetElementSpaceSize
());
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ushort
,
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MPerXdl
...
@@ -1610,17 +1613,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1610,17 +1613,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1
,
// DstScalarPerVector
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
make_multi_index
(
block_work_idx_n
,
// MBlockId
make_multi_index
(
0
,
// MBlockId
0
,
// NBlockId
block_work_idx_n
,
// NBlockId
0
,
// mrepeat
0
,
// mrepeat
0
,
// nrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
,
// NInputIndex
0
),
wave_m_n_id
[
I1
]
),
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
//
//
...
@@ -1743,8 +1746,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1743,8 +1746,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_grid_desc_mblock_mperblock_oblock_operblock
,
make_multi_index
(
1
,
0
,
0
,
0
));
y_grid_desc_mblock_mperblock_oblock_operblock
,
make_multi_index
(
1
,
0
,
0
,
0
));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
continue
;
continue
;
}
}
...
@@ -1891,35 +1894,144 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1891,35 +1894,144 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// save z to global
// save z to global
if
(
p_z_grid
)
if
(
p_z_grid
)
{
{
// P_dropped
// 8d thread_desc in thread scope
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
c_thread_lengths
=
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
M3
=
c_block_lengths
[
I5
];
constexpr
auto
M4
=
c_block_lengths
[
I6
];
constexpr
auto
N2
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto
global_elem_id_raw
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
auto
global_elem_id
=
(
global_elem_id_raw
%
4
)
*
MRaw
+
int
(
global_elem_id_raw
/
4
)
*
4
;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
// dropout
// z_tenor_buffer_tmp -> z_grid_buf_tmp -> shuffle -> z_tenor_buffer -> z_grid_buf
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_tenor_buffer
),
true
,
true
>(
decltype
(
n0
),
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
MRaw
);
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
z_grid_buf
);
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
-
n0
.
value
,
0
,
0
,
0
,
0
,
0
,
0
));
}
}
else
else
{
{
ignore
=
z_grid_buf
;
ignore
=
z_grid_buf
;
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
M3
=
c_block_lengths
[
I5
];
constexpr
auto
M4
=
c_block_lengths
[
I6
];
constexpr
auto
N2
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto
global_elem_id_raw
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
auto
global_elem_id
=
(
global_elem_id_raw
%
4
)
*
MRaw
+
int
(
global_elem_id_raw
/
4
)
*
4
;
// P_dropped
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
blockwise_dropout
s_slash_p_thread_buf
,
ph
);
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
MRaw
);
}
}
block_sync_lds
();
// wait for gemm1 LDS read
block_sync_lds
();
// wait for gemm1 LDS read
...
@@ -2183,8 +2295,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -2183,8 +2295,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_multi_index
(
1
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
while
(
++
gemm0_m_block_outer_index
<
num_gemm0_m_block_outer_loop
);
// end j loop
}
while
(
++
gemm0_m_block_outer_index
<
num_gemm0_m_block_outer_loop
);
// end j loop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment