Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cd6e9903
Commit
cd6e9903
authored
Jun 06, 2023
by
guangzlu
Browse files
added dropout for bwd v5
parent
dc8e0148
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
226 additions
and
104 deletions
+226
-104
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v5.cpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
+23
-13
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
+199
-87
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
View file @
cd6e9903
...
...
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM
64
// DIM should be a multiple of 8.
#define DIM
128
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
@@ -715,13 +715,13 @@ int run(int argc, char* argv[])
ck
::
index_t
M
=
500
;
// 512
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
1
;
// 54
ck
::
index_t
G1
=
1
;
// 16
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G1
=
6
;
// 16
bool
input_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.
1
;
float
p_drop
=
0.
2
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
View file @
cd6e9903
...
...
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
...
...
@@ -70,8 +70,8 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -84,7 +84,9 @@ __global__ void
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -134,7 +136,7 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
...
...
@@ -143,6 +145,9 @@ __global__ void
c0_matrix_mask
,
p_drop
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
}
}
...
...
@@ -166,7 +171,7 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
...
...
@@ -175,6 +180,9 @@ __global__ void
c0_matrix_mask
,
p_drop
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
}
#else
...
...
@@ -831,8 +839,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
z_grid_desc_m_n_
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
z_grid_desc_m_n_
);
// Print();
}
...
...
@@ -892,8 +900,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
;
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
...
@@ -960,7 +968,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
...
...
@@ -994,7 +1002,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
...
...
@@ -1006,7 +1014,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
View file @
cd6e9903
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment