Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
aea324d2
Commit
aea324d2
authored
Sep 20, 2023
by
letaoqin
Browse files
Merge branch 'mha-train-develop' into mha-train-develop-grad-bias
parents
73611570
f04ec574
Changes
26
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
230 additions
and
274 deletions
+230
-274
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+7
-9
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+7
-9
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+7
-9
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+190
-241
include/ck/utility/philox_rand.hpp
include/ck/utility/philox_rand.hpp
+13
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp
...rary/reference_tensor_operation/cpu/reference_dropout.hpp
+6
-6
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
aea324d2
...
@@ -107,8 +107,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -107,8 +107,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I8
=
Number
<
8
>
{};
static
constexpr
auto
I9
=
Number
<
9
>
{};
static
constexpr
auto
WaveSize
=
64
;
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// K1 should be Number<...>
...
@@ -133,8 +131,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -133,8 +131,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
constexpr
auto
V_K0
=
Gemm1NPerBlock
/
KPerBlock
;
static
constexpr
auto
V_K0
=
Gemm1NPerBlock
/
KPerBlock
;
static
constexpr
auto
V_N1
=
NXdlPerWave
;
static
constexpr
auto
V_N1
=
NXdlPerWave
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_
8x
16() generates
8
random numbers each time
// get_random_16
x8
() generates
16
random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
//
16
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
16
>
{};
//
32
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
@@ -1541,8 +1539,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1541,8 +1539,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
u
shor
t
p_dropout_in_
16bits
=
const
u
int8_
t
p_dropout_in_
uint8_t
=
__builtin_amdgcn_readfirstlane
(
std
::
floor
(
p_dropout
*
6
55
35
.0
));
__builtin_amdgcn_readfirstlane
(
uint8_t
(
std
::
floor
(
p_dropout
*
2
55.0
))
)
;
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
rp_dropout
);
rp_dropout
);
...
@@ -1889,7 +1887,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1889,7 +1887,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
decltype
(
thread_slice_desc_m_n
)
>
{};
decltype
(
thread_slice_desc_m_n
)
>
{};
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
p_dropout_in_
16bits
,
rp_dropout
};
p_dropout_in_
uint8_t
,
rp_dropout
};
auto
lse_grid_desc_mb_m0_m1_m2_m3_m4
=
auto
lse_grid_desc_mb_m0_m1_m2_m3_m4
=
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
<
decltype
(
s_blockwise_gemm
)
>
(
lse_grid_desc_m
);
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
<
decltype
(
s_blockwise_gemm
)
>
(
lse_grid_desc_m
);
...
@@ -1958,7 +1956,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1958,7 +1956,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
n2
));
// NPerXdl
n2
));
// NPerXdl
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
u
shor
t
,
u
int8_
t
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
true
>
z_tensor_buffer
;
z_tensor_buffer
;
...
@@ -1968,7 +1966,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1968,7 +1966,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
u
shor
t
,
u
int8_
t
,
ZDataType
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
aea324d2
...
@@ -98,8 +98,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -98,8 +98,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I8
=
Number
<
8
>
{};
static
constexpr
auto
I9
=
Number
<
9
>
{};
static
constexpr
auto
WaveSize
=
64
;
static
constexpr
auto
WaveSize
=
64
;
...
@@ -119,8 +117,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -119,8 +117,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
constexpr
auto
V_K0
=
KPerBlock
/
V_K1
/
V_K2
;
static
constexpr
auto
V_K0
=
KPerBlock
/
V_K1
/
V_K2
;
static
constexpr
auto
V_N1
=
NXdlPerWave
;
static
constexpr
auto
V_N1
=
NXdlPerWave
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_
8x
16() generates
8
random numbers each time
// get_random_16
x8
() generates
16
random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
//
16
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
16
>
{};
//
32
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
@@ -1533,8 +1531,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1533,8 +1531,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
u
shor
t
p_dropout_in_
16bits
=
const
u
int8_
t
p_dropout_in_
uint8_t
=
__builtin_amdgcn_readfirstlane
(
std
::
floor
(
p_dropout
*
6
55
35
.0
));
__builtin_amdgcn_readfirstlane
(
uint8_t
(
std
::
floor
(
p_dropout
*
2
55.0
))
)
;
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
rp_dropout
);
rp_dropout
);
...
@@ -1852,7 +1850,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1852,7 +1850,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
decltype
(
thread_slice_desc_m_n
)
>
{};
decltype
(
thread_slice_desc_m_n
)
>
{};
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
p_dropout_in_
16bits
,
rp_dropout
};
p_dropout_in_
uint8_t
,
rp_dropout
};
auto
lse_grid_desc_mb_m0_m1_m2_m3_m4
=
auto
lse_grid_desc_mb_m0_m1_m2_m3_m4
=
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
<
decltype
(
s_blockwise_gemm
)
>
(
lse_grid_desc_m
);
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
<
decltype
(
s_blockwise_gemm
)
>
(
lse_grid_desc_m
);
...
@@ -1902,7 +1900,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1902,7 +1900,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
n2
));
// NPerXdl
n2
));
// NPerXdl
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
u
shor
t
,
u
int8_
t
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
true
>
z_tensor_buffer
;
z_tensor_buffer
;
...
@@ -1912,7 +1910,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1912,7 +1910,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
u
shor
t
,
u
int8_
t
,
ZDataType
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
aea324d2
...
@@ -106,8 +106,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -106,8 +106,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I8
=
Number
<
8
>
{};
static
constexpr
auto
I9
=
Number
<
9
>
{};
static
constexpr
auto
WaveSize
=
64
;
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// K1 should be Number<...>
...
@@ -132,8 +130,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -132,8 +130,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
V_K0
=
Gemm1NPerBlock
/
KPerBlock
;
static
constexpr
auto
V_K0
=
Gemm1NPerBlock
/
KPerBlock
;
static
constexpr
auto
V_N1
=
NXdlPerWave
;
static
constexpr
auto
V_N1
=
NXdlPerWave
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_
8x
16() generates
8
random numbers each time
// get_random_16
x8
() generates
16
random numbers each time
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
8
>
{};
//
16
static
constexpr
auto
DropoutTile
=
Number
<
DropoutNThread
*
16
>
{};
//
32
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
@@ -1599,8 +1597,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1599,8 +1597,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
u
shor
t
p_dropout_in_
16bits
=
const
u
int8_
t
p_dropout_in_
uint8_t
=
__builtin_amdgcn_readfirstlane
(
std
::
floor
(
p_dropout
*
6
55
35
.0
));
__builtin_amdgcn_readfirstlane
(
uint8_t
(
std
::
floor
(
p_dropout
*
2
55.0
))
)
;
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
rp_dropout
);
rp_dropout
);
...
@@ -1947,7 +1945,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1947,7 +1945,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
decltype
(
thread_slice_desc_m_n
)
>
{};
decltype
(
thread_slice_desc_m_n
)
>
{};
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
auto
blockwise_dropout
=
BlockwiseDropout
<
FloatGemmAcc
,
decltype
(
thread_slice_desc_m_n
)
>
{
p_dropout_in_
16bits
,
rp_dropout
};
p_dropout_in_
uint8_t
,
rp_dropout
};
auto
lse_grid_desc_mb_m0_m1_m2_m3_m4
=
auto
lse_grid_desc_mb_m0_m1_m2_m3_m4
=
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
<
decltype
(
s_blockwise_gemm
)
>
(
lse_grid_desc_m
);
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4
<
decltype
(
s_blockwise_gemm
)
>
(
lse_grid_desc_m
);
...
@@ -1997,7 +1995,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1997,7 +1995,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
n2
));
// NPerXdl
n2
));
// NPerXdl
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
u
shor
t
,
u
int8_
t
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
true
>
z_tensor_buffer
;
z_tensor_buffer
;
...
@@ -2007,7 +2005,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -2007,7 +2005,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
u
shor
t
,
u
int8_
t
,
ZDataType
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
aea324d2
This diff is collapsed.
Click to expand it.
include/ck/utility/philox_rand.hpp
View file @
aea324d2
...
@@ -84,6 +84,19 @@ class philox
...
@@ -84,6 +84,19 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
}
__device__
void
get_random_16x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
)
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp_ph
.
x
;
out_tmp
[
1
]
=
tmp_ph
.
y
;
out_tmp
[
2
]
=
tmp_ph
.
z
;
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
__device__
void
get_random_4x16
(
ushort
*
out
,
const
unsigned
long
long
subsequence
)
__device__
void
get_random_4x16
(
ushort
*
out
,
const
unsigned
long
long
subsequence
)
{
{
uint4
tmp_ph
;
uint4
tmp_ph
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp
View file @
aea324d2
...
@@ -25,19 +25,19 @@ struct ReferenceDropout : public device::BaseOperator
...
@@ -25,19 +25,19 @@ struct ReferenceDropout : public device::BaseOperator
Argument
(
const
Tensor
<
RefDataType
>&
ref
,
Argument
(
const
Tensor
<
RefDataType
>&
ref
,
const
Tensor
<
InDataType
>&
in
,
const
Tensor
<
InDataType
>&
in
,
Tensor
<
OutDataType
>&
out
,
Tensor
<
OutDataType
>&
out
,
RefDataType
p_dropout_in_
16bits
,
RefDataType
p_dropout_in_
uint8_t
,
float
rp_dropout
)
float
rp_dropout
)
:
ref_
(
ref
),
:
ref_
(
ref
),
in_
(
in
),
in_
(
in
),
out_
(
out
),
out_
(
out
),
p_dropout_in_
16bits
_
(
p_dropout_in_
16bits
),
p_dropout_in_
uint8_t
_
(
p_dropout_in_
uint8_t
),
rp_dropout_
(
rp_dropout
)
rp_dropout_
(
rp_dropout
)
{
{
}
}
const
Tensor
<
RefDataType
>&
ref_
;
const
Tensor
<
RefDataType
>&
ref_
;
const
Tensor
<
InDataType
>&
in_
;
const
Tensor
<
InDataType
>&
in_
;
Tensor
<
OutDataType
>&
out_
;
Tensor
<
OutDataType
>&
out_
;
RefDataType
p_dropout_in_
16bits
_
;
RefDataType
p_dropout_in_
uint8_t
_
;
float
rp_dropout_
;
float
rp_dropout_
;
};
};
...
@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
...
@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
{
{
arg
.
out_
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
arg
.
out_
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
self
(
idx
)
=
arg
.
ref_
(
idx
)
<=
arg
.
p_dropout_in_
16bits
_
arg
.
ref_
(
idx
)
<=
arg
.
p_dropout_in_
uint8_t
_
?
ck
::
type_convert
<
OutDataType
>
(
ck
::
type_convert
<
float
>
(
arg
.
in_
(
idx
))
*
?
ck
::
type_convert
<
OutDataType
>
(
ck
::
type_convert
<
float
>
(
arg
.
in_
(
idx
))
*
ck
::
type_convert
<
float
>
(
arg
.
rp_dropout_
))
ck
::
type_convert
<
float
>
(
arg
.
rp_dropout_
))
:
0
;
:
0
;
...
@@ -74,10 +74,10 @@ struct ReferenceDropout : public device::BaseOperator
...
@@ -74,10 +74,10 @@ struct ReferenceDropout : public device::BaseOperator
static
auto
MakeArgument
(
const
Tensor
<
RefDataType
>&
ref
,
static
auto
MakeArgument
(
const
Tensor
<
RefDataType
>&
ref
,
const
Tensor
<
InDataType
>&
in
,
const
Tensor
<
InDataType
>&
in
,
Tensor
<
OutDataType
>&
out
,
Tensor
<
OutDataType
>&
out
,
RefDataType
p_dropout_in_
16bits
,
RefDataType
p_dropout_in_
uint8_t
,
float
rp_dropout
)
float
rp_dropout
)
{
{
return
Argument
{
ref
,
in
,
out
,
p_dropout_in_
16bits
,
rp_dropout
};
return
Argument
{
ref
,
in
,
out
,
p_dropout_in_
uint8_t
,
rp_dropout
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment