Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
12dcba20
Unverified
Commit
12dcba20
authored
Sep 11, 2023
by
Dan Yao
Committed by
GitHub
Sep 11, 2023
Browse files
Merge pull request #903 from ROCmSoftwarePlatform/mha-train-develop-dropout8bit
Mha train develop dropout8bit
parents
172835a5
e9707881
Changes
24
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
216 additions
and
254 deletions
+216
-254
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+7
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+190
-241
include/ck/utility/philox_rand.hpp
include/ck/utility/philox_rand.hpp
+13
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp
...rary/reference_tensor_operation/cpu/reference_dropout.hpp
+6
-6
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
12dcba20
...
@@ -130,8 +130,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -130,8 +130,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
V_K0
=
Gemm1NPerBlock
/
KPerBlock
;
static
constexpr
auto
V_K0
=
Gemm1NPerBlock
/
KPerBlock
;
static
constexpr
auto
V_N1
=
NXdlPerWave
;
static
constexpr
auto
V_N1
=
NXdlPerWave
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_
8x
16() generates
8
random numbers each time
// get_random_16
x8
() generates
16
random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
//
16
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
16
>
{};
//
32
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
@@ -1553,8 +1553,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1553,8 +1553,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
u
shor
t
p_dropout_in_
16bits
=
const
u
int8_
t
p_dropout_in_
uint8_t
=
__builtin_amdgcn_readfirstlane
(
std
::
floor
(
p_dropout
*
6
55
35
.0
));
__builtin_amdgcn_readfirstlane
(
uint8_t
(
std
::
floor
(
p_dropout
*
2
55.0
))
)
;
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
rp_dropout
);
rp_dropout
);
...
@@ -1901,7 +1901,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1901,7 +1901,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
decltype
(
thread_slice_desc_m_n
)
>
{};
decltype
(
thread_slice_desc_m_n
)
>
{};
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
p_dropout_in_
16bits
,
rp_dropout
};
p_dropout_in_
uint8_t
,
rp_dropout
};
auto
lse_grid_desc_mb_m0_m1_m2_m3_m4
=
auto
lse_grid_desc_mb_m0_m1_m2_m3_m4
=
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
<
decltype
(
s_blockwise_gemm
)
>
(
lse_grid_desc_m
);
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
<
decltype
(
s_blockwise_gemm
)
>
(
lse_grid_desc_m
);
...
@@ -1951,7 +1951,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1951,7 +1951,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
n2
));
// NPerXdl
n2
));
// NPerXdl
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
u
shor
t
,
u
int8_
t
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
true
>
z_tensor_buffer
;
z_tensor_buffer
;
...
@@ -1961,7 +1961,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1961,7 +1961,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
u
shor
t
,
u
int8_
t
,
ZDataType
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
12dcba20
This diff is collapsed.
Click to expand it.
include/ck/utility/philox_rand.hpp
View file @
12dcba20
...
@@ -84,6 +84,19 @@ class philox
...
@@ -84,6 +84,19 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
}
__device__
void
get_random_16x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
)
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp_ph
.
x
;
out_tmp
[
1
]
=
tmp_ph
.
y
;
out_tmp
[
2
]
=
tmp_ph
.
z
;
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
__device__
void
get_random_4x16
(
ushort
*
out
,
const
unsigned
long
long
subsequence
)
__device__
void
get_random_4x16
(
ushort
*
out
,
const
unsigned
long
long
subsequence
)
{
{
uint4
tmp_ph
;
uint4
tmp_ph
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp
View file @
12dcba20
...
@@ -25,19 +25,19 @@ struct ReferenceDropout : public device::BaseOperator
...
@@ -25,19 +25,19 @@ struct ReferenceDropout : public device::BaseOperator
Argument
(
const
Tensor
<
RefDataType
>&
ref
,
Argument
(
const
Tensor
<
RefDataType
>&
ref
,
const
Tensor
<
InDataType
>&
in
,
const
Tensor
<
InDataType
>&
in
,
Tensor
<
OutDataType
>&
out
,
Tensor
<
OutDataType
>&
out
,
RefDataType
p_dropout_in_
16bits
,
RefDataType
p_dropout_in_
uint8_t
,
float
rp_dropout
)
float
rp_dropout
)
:
ref_
(
ref
),
:
ref_
(
ref
),
in_
(
in
),
in_
(
in
),
out_
(
out
),
out_
(
out
),
p_dropout_in_
16bits
_
(
p_dropout_in_
16bits
),
p_dropout_in_
uint8_t
_
(
p_dropout_in_
uint8_t
),
rp_dropout_
(
rp_dropout
)
rp_dropout_
(
rp_dropout
)
{
{
}
}
const
Tensor
<
RefDataType
>&
ref_
;
const
Tensor
<
RefDataType
>&
ref_
;
const
Tensor
<
InDataType
>&
in_
;
const
Tensor
<
InDataType
>&
in_
;
Tensor
<
OutDataType
>&
out_
;
Tensor
<
OutDataType
>&
out_
;
RefDataType
p_dropout_in_
16bits
_
;
RefDataType
p_dropout_in_
uint8_t
_
;
float
rp_dropout_
;
float
rp_dropout_
;
};
};
...
@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
...
@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
{
{
arg
.
out_
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
arg
.
out_
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
self
(
idx
)
=
arg
.
ref_
(
idx
)
<=
arg
.
p_dropout_in_
16bits
_
arg
.
ref_
(
idx
)
<=
arg
.
p_dropout_in_
uint8_t
_
?
ck
::
type_convert
<
OutDataType
>
(
ck
::
type_convert
<
float
>
(
arg
.
in_
(
idx
))
*
?
ck
::
type_convert
<
OutDataType
>
(
ck
::
type_convert
<
float
>
(
arg
.
in_
(
idx
))
*
ck
::
type_convert
<
float
>
(
arg
.
rp_dropout_
))
ck
::
type_convert
<
float
>
(
arg
.
rp_dropout_
))
:
0
;
:
0
;
...
@@ -74,10 +74,10 @@ struct ReferenceDropout : public device::BaseOperator
...
@@ -74,10 +74,10 @@ struct ReferenceDropout : public device::BaseOperator
static
auto
MakeArgument
(
const
Tensor
<
RefDataType
>&
ref
,
static
auto
MakeArgument
(
const
Tensor
<
RefDataType
>&
ref
,
const
Tensor
<
InDataType
>&
in
,
const
Tensor
<
InDataType
>&
in
,
Tensor
<
OutDataType
>&
out
,
Tensor
<
OutDataType
>&
out
,
RefDataType
p_dropout_in_
16bits
,
RefDataType
p_dropout_in_
uint8_t
,
float
rp_dropout
)
float
rp_dropout
)
{
{
return
Argument
{
ref
,
in
,
out
,
p_dropout_in_
16bits
,
rp_dropout
};
return
Argument
{
ref
,
in
,
out
,
p_dropout_in_
uint8_t
,
rp_dropout
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment