Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
78c1482a
Commit
78c1482a
authored
Jun 05, 2023
by
guangzlu
Browse files
add dorpout based on global position for bwd v4
parent
9ec592a6
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
326 additions
and
96 deletions
+326
-96
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+86
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
+20
-10
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
+220
-85
No files found.
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
78c1482a
...
...
@@ -124,7 +124,7 @@ struct BlockwiseDropout
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout_v1r1
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
)
ApplyDropout_v1r1
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
)
//
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
...
...
@@ -157,6 +157,43 @@ struct BlockwiseDropout
});
}
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropoutAttnBwd
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
index_t
MRaw
)
//
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout_v1r2
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
...
...
@@ -204,6 +241,54 @@ struct BlockwiseDropout
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropoutAttnBwdSaveZ
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
,
index_t
MRaw
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
ushort
tmp_id
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
*
MRaw
;
}
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp_id
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout_v2
(
CThreadBuffer
&
in_thread_buf
,
ZThreadBuffer
&
z_thread_buf
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
View file @
78c1482a
...
...
@@ -39,7 +39,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
...
...
@@ -71,7 +71,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
...
...
@@ -85,7 +85,9 @@ __global__ void
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -144,6 +146,9 @@ __global__ void
c0_matrix_mask
,
p_drop
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
}
}
...
...
@@ -176,6 +181,9 @@ __global__ void
c0_matrix_mask
,
p_drop
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
}
#else
...
...
@@ -818,8 +826,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
z_grid_desc_m_n_
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
z_grid_desc_m_n_
);
// Print();
}
...
...
@@ -879,8 +887,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
;
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
...
@@ -943,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
...
...
@@ -977,7 +985,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
...
...
@@ -989,7 +997,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
View file @
78c1482a
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment