gaoqiong / composable_kernel

Commit aea324d2, authored Sep 20, 2023 by letaoqin

Merge branch 'mha-train-develop' into mha-train-develop-grad-bias

Parents: 73611570, f04ec574
Changes: 26. Showing 6 changed files with 230 additions and 274 deletions.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp  +7 -9
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp  +7 -9
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp  +7 -9
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp  +190 -241
include/ck/utility/philox_rand.hpp  +13 -0
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp  +6 -6

include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp

@@ -107,8 +107,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
    static constexpr auto I8 = Number<8>{};
    static constexpr auto I9 = Number<9>{};

    static constexpr auto WaveSize = 64;
    // K1 should be Number<...>
@@ -133,8 +131,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
    static constexpr auto V_N1 = NXdlPerWave;

    static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-   // get_random_8x16() generates 8 random numbers each time
-   static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+   // get_random_16x8() generates 16 random numbers each time
+   static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1541,8 +1539,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    {
        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-       const ushort p_dropout_in_16bits =
-           __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+       const uint8_t p_dropout_in_uint8_t =
+           __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                     rp_dropout);
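Aside (not part of the diff): the quantization that replaces the 16-bit threshold is easy to check on the host. A minimal sketch, assuming p_drop = 0.5 so the rounding is exact; the kernel additionally wraps the value in __builtin_amdgcn_readfirstlane to make it wave-uniform:

#include <cassert>
#include <cmath>
#include <cstdint>

int main()
{
    const float p_drop     = 0.5f;             // dropout probability (illustrative)
    const float p_dropout  = 1.0f - p_drop;    // keep probability, as in the hunk above
    const float rp_dropout = 1.0f / p_dropout; // rescale factor for surviving elements

    // New encoding: one byte with a 255.0 scale (the old code used 16 bits and 65535.0).
    const uint8_t p_dropout_in_uint8_t = static_cast<uint8_t>(std::floor(p_dropout * 255.0));

    assert(p_dropout_in_uint8_t == 127); // floor(0.5 * 255) == 127
    assert(rp_dropout == 2.0f);
    return 0;
}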
@@ -1889,7 +1887,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                                              decltype(thread_slice_desc_m_n)>{};
        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-           p_dropout_in_16bits, rp_dropout};
+           p_dropout_in_uint8_t, rp_dropout};

        auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
            MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1958,7 +1956,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                                            n2)); // NPerXdl

        StaticBuffer<AddressSpaceEnum::Vgpr,
-                    ushort,
+                    uint8_t,
                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                     true>
            z_tensor_buffer;
@@ -1968,7 +1966,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());

        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-           ushort,
+           uint8_t,
            ZDataType,
            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),

include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp

@@ -98,8 +98,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
    static constexpr auto I8 = Number<8>{};
    static constexpr auto I9 = Number<9>{};

    static constexpr auto WaveSize = 64;
@@ -119,8 +117,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
    static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
    static constexpr auto V_N1 = NXdlPerWave;

    static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-   // get_random_8x16() generates 8 random numbers each time
-   static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+   // get_random_16x8() generates 16 random numbers each time
+   static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1533,8 +1531,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
    {
        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-       const ushort p_dropout_in_16bits =
-           __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+       const uint8_t p_dropout_in_uint8_t =
+           __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                     rp_dropout);
@@ -1852,7 +1850,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                                              decltype(thread_slice_desc_m_n)>{};
        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-           p_dropout_in_16bits, rp_dropout};
+           p_dropout_in_uint8_t, rp_dropout};

        auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
            MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1902,7 +1900,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                                            n2)); // NPerXdl

        StaticBuffer<AddressSpaceEnum::Vgpr,
-                    ushort,
+                    uint8_t,
                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                     true>
            z_tensor_buffer;
@@ -1912,7 +1910,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());

        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-           ushort,
+           uint8_t,
            ZDataType,
            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),

include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp

@@ -106,8 +106,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
    static constexpr auto I8 = Number<8>{};
    static constexpr auto I9 = Number<9>{};

    static constexpr auto WaveSize = 64;
    // K1 should be Number<...>
@@ -132,8 +130,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
    static constexpr auto V_N1 = NXdlPerWave;

    static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-   // get_random_8x16() generates 8 random numbers each time
-   static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+   // get_random_16x8() generates 16 random numbers each time
+   static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1599,8 +1597,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    {
        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-       const ushort p_dropout_in_16bits =
-           __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+       const uint8_t p_dropout_in_uint8_t =
+           __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                     rp_dropout);
@@ -1947,7 +1945,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                              decltype(thread_slice_desc_m_n)>{};
        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-           p_dropout_in_16bits, rp_dropout};
+           p_dropout_in_uint8_t, rp_dropout};

        auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
            MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1997,7 +1995,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                            n2)); // NPerXdl

        StaticBuffer<AddressSpaceEnum::Vgpr,
-                    ushort,
+                    uint8_t,
                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                     true>
            z_tensor_buffer;
@@ -2007,7 +2005,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());

        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-           ushort,
+           uint8_t,
            ZDataType,
            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),

include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp

@@ -113,8 +113,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
    static constexpr auto I8 = Number<8>{};
    static constexpr auto I9 = Number<9>{};

    static constexpr auto WaveSize = 64;
@@ -134,17 +132,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
    static constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;

    static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-   // get_random_8x16() generates 8 random numbers each time
-   static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
-   static constexpr auto DropoutMThread = DropoutTile; // 16
-   static constexpr auto DropoutTilePerXdl = NPerXdl / DropoutTile; // 2
-   static constexpr auto DropoutStep = Number<DropoutStepValue>{}; // 1 2 4
-   static constexpr auto DropoutNRepeat =
-       Number<math::integer_divide_ceil(DropoutStep, DropoutTilePerXdl)>{}; // 1 1 2
-   static constexpr auto DropoutGroupPerTile =
-       Number<mfma.num_groups_per_blk / DropoutTilePerXdl>{}; // 2
-   static constexpr auto DropoutStepPerXdl =
-       Number<math::min(DropoutStep, DropoutTilePerXdl)>{}; // 1 2 2
+   // get_random_16x8() generates 16 random numbers each time
+   static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32
+   static constexpr auto DropoutStep = Number<DropoutStepValue>{}; // 1 2

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
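Aside (not part of the diff): the inline "// 2" and "// 32" annotations record one concrete configuration. A small sketch reproducing that arithmetic, assuming mfma.num_input_blks == 2 as the comments indicate:

int main()
{
    constexpr int num_input_blks  = 2;              // assumed, per the "// 2" comment
    constexpr int DropoutNThread  = num_input_blks;
    // get_random_16x8() yields 16 random bytes per call, twice the 8 values
    // per call of the old get_random_8x16() path, so the dropout tile doubles.
    constexpr int DropoutTile_old = DropoutNThread * 8;  // 16
    constexpr int DropoutTile_new = DropoutNThread * 16; // 32
    static_assert(DropoutTile_old == 16 && DropoutTile_new == 32, "matches the diff comments");
    return 0;
}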
@@ -152,51 +142,45 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    // C desc for source in gridwise copy
-   __host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
+   __host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
        const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
    {
        const auto M = z_grid_desc_m_n.GetLength(I0);
        const auto N = z_grid_desc_m_n.GetLength(I1);

        const auto M0 = M / MPerBlock;
-       const auto N0 = N / (DropoutNRepeat * NPerXdl);
+       const auto N0 = N / (DropoutStep * NPerXdl);
        constexpr auto M1 = MXdlPerWave;
-       constexpr auto N1 = DropoutNRepeat;
+       constexpr auto N1 = DropoutStep;
        constexpr auto M2 = Gemm0MWaves;
        constexpr auto N2 = Gemm0NWaves;
-       constexpr auto M3 = DropoutTilePerXdl;
-       constexpr auto N3 = DropoutStepPerXdl;
-       constexpr auto M4 = DropoutTile;
-       constexpr auto N4 = DropoutGroupPerTile;
-       constexpr auto N5 = mfma.num_input_blks;
-       constexpr auto N6 = mfma.group_size;
+       constexpr auto M3 = DropoutTile;
+       constexpr auto N3 = mfma.num_groups_per_blk;
+       constexpr auto N4 = mfma.num_input_blks;
+       constexpr auto N5 = mfma.group_size;

        return transform_tensor_descriptor(
            z_grid_desc_m_n,
-           make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
-                      make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4, N5, N6))),
+           make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3)),
+                      make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4, N5))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
-           make_tuple(Sequence<0, 2, 4, 6, 8>{}, Sequence<1, 3, 5, 7, 9, 10, 11>{}));
+           make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
    }

-   __host__ __device__ static constexpr auto GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5()
+   __host__ __device__ static constexpr auto GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto M0 = MXdlPerWave;
-       constexpr auto N0 = DropoutNRepeat;
+       constexpr auto N0 = DropoutStep;
        constexpr auto M1 = Gemm0MWaves;
        constexpr auto N1 = Gemm0NWaves;
-       constexpr auto M2 = DropoutTilePerXdl;
-       constexpr auto N2 = DropoutStepPerXdl;
-       constexpr auto M3 = DropoutTile;
-       constexpr auto N3 = DropoutGroupPerTile;
-       constexpr auto N4 = mfma.num_input_blks;
-       constexpr auto N5 = mfma.group_size;
+       constexpr auto M2 = DropoutTile;
+       constexpr auto N2 = mfma.num_groups_per_blk;
+       constexpr auto N3 = mfma.num_input_blks;
+       constexpr auto N4 = mfma.group_size;

-       constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-           make_naive_tensor_descriptor_packed(
-               make_tuple(M0, N0, M1, N1, M2, N2, M3, N3, N4, N5));
+       constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+           make_naive_tensor_descriptor_packed(make_tuple(M0, N0, M1, N1, M2, N2, N3, N4));

-       return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
+       return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
    }

    __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
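Aside (not part of the diff): both descriptor builders above rest on make_unmerge_transform, which splits one flat axis into several sub-dimensions. A toy model of that index arithmetic, with illustrative factor lengths rather than values from a real instance:

#include <array>
#include <cassert>

int main()
{
    // Split a flat column index n into (N0..N5); product of the lengths is 256.
    constexpr std::array<int, 6> len = {2, 2, 2, 4, 2, 4};
    int n = 173; // any index < 256
    std::array<int, 6> idx{};
    for(int d = 5; d >= 0; --d) // last dimension varies fastest (row-major)
    {
        idx[d] = n % len[d];
        n /= len[d];
    }
    int back = 0; // recombine to confirm the round trip
    for(int d = 0; d < 6; ++d)
        back = back * len[d] + idx[d];
    assert(back == 173);
    return 0;
}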
@@ -317,7 +301,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
        const index_t z_block_bytes_end =
-           SharedMemTrait::z_shuffle_block_space_size * sizeof(ushort);
+           SharedMemTrait::z_shuffle_block_space_size * sizeof(uint8_t);

        return math::max(gemm0_bytes_end,
                         gemm1_bytes_end,
@@ -468,8 +452,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
-   using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6 = remove_cvref_t<decltype(
-       MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(ZGridDesc_M_N{}))>;
+   using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
+       MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;

    struct SharedMemTrait
    {
@@ -507,10 +491,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        // LDS allocation for Z shuffle in LDS
-       static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-           GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
+       static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+           GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
        static constexpr auto z_shuffle_block_space_size =
-           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize();
+           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize();
    };

    template <bool HasMainKBlockLoop,
@@ -538,12 +522,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
        const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
            c_grid_desc_mblock_mperblock_nblock_nperblock,
-       const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6&
-           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+       const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
+           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
        const LSEGridDesc_M& lse_grid_desc_m,
        const Block2CTileMap& block_2_ctile_map,
        const C0MatrixMask& c0_matrix_mask,
-       const ushort p_dropout_in_16bits,
+       const uint8_t p_dropout_in_uint8_t,
        FloatGemmAcc p_dropout_rescale,
        ck::philox& ph,
        const index_t z_random_matrix_offset,
@@ -894,7 +878,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                              decltype(thread_slice_desc_m_n)>{};
        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-           p_dropout_in_16bits, p_dropout_rescale};
+           p_dropout_in_uint8_t, p_dropout_rescale};

        const index_t num_gemm1_k_block_outer_loop =
            b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
@@ -992,26 +976,22 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                             wave_m_n_id[I0], // NInputIndex
                                             0));

-       constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
-           make_naive_tensor_descriptor_packed(make_tuple(m0,             // MRepeat
-                                                          DropoutNRepeat, // NRepeat
-                                                          m1,             // MWaveId
-                                                          n1,             // NWaveId
-                                                          I1,
-                                                          DropoutStepPerXdl,
-                                                          m2,
-                                                          DropoutGroupPerTile,
-                                                          n3,
-                                                          n4)); // register number
+       constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
+           make_naive_tensor_descriptor_packed(make_tuple(m0,          // MRepeat
+                                                          DropoutStep, // NRepeat
+                                                          m1,          // MWaveId
+                                                          n1,          // NWaveId
+                                                          m2,
+                                                          n2,
+                                                          n3,
+                                                          n4)); // RegisterNum

-       constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = // for blockwise copy
-           make_naive_tensor_descriptor_packed(make_tuple(m0,             // MRepeat
-                                                          DropoutNRepeat, // NRepeat
-                                                          m1,             // MWaveId
-                                                          n1,             // NWaveId
-                                                          I1,
-                                                          DropoutStepPerXdl,
-                                                          DropoutGroupPerTile,
-                                                          n3,
-                                                          n4,
-                                                          m2));
+       constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = // for blockwise copy
+           make_naive_tensor_descriptor_packed(make_tuple(m0,          // MRepeat
+                                                          DropoutStep, // NRepeat
+                                                          m1,          // MWaveId
+                                                          n1,          // NWaveId
+                                                          n2,
+                                                          n3,
+                                                          n4, // RegisterNum
+                                                          m2));
@@ -1020,180 +1000,150 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
        //
        // z vgpr copy to global
        //

        // z matrix threadwise desc
-       constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
-           make_naive_tensor_descriptor_packed(make_tuple(I1,             // MBlockId
-                                                          I1,             // NBlockId
-                                                          m0,             // MRepeat
-                                                          DropoutNRepeat, // NRepeat
-                                                          m1,             // MWaveId
-                                                          n1,             // NWaveId
-                                                          I1,
-                                                          DropoutStepPerXdl,
-                                                          m2,
-                                                          DropoutGroupPerTile,
-                                                          n3,
-                                                          n4)); // RegisterNum
+       constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+           make_naive_tensor_descriptor_packed(make_tuple(I1,          // MBlockId
+                                                          I1,          // NBlockId
+                                                          m0,          // MRepeat
+                                                          DropoutStep, // NRepeat
+                                                          m1,          // MWaveId
+                                                          n1,          // NWaveId
+                                                          m2,
+                                                          n2,
+                                                          n3,
+                                                          n4)); // RegisterNum

-       constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-           GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
+       constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+           GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

-       constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I0); // 1
-       constexpr auto ZN0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I1); // 1 1 2
-       constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I2); // 4
-       constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I3); // 1
-       constexpr auto ZM2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I4); // 2
-       constexpr auto ZN2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I5); // 1 2 2
-       constexpr auto ZM3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I6); // 16
-       constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I7); // 2
-       constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I8); // 2
-       constexpr auto ZN5 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I9); // 4
+       constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); // 1
+       constexpr auto ZN0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); // 1 2
+       constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); // 4
+       constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); // 1
+       constexpr auto ZN2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); // 4
+       constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); // 2
+       constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); // 4

-       constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
-           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-           make_tuple(make_pass_through_transform(ZM0),
-                      make_pass_through_transform(ZN0),
-                      make_pass_through_transform(ZM1),
-                      make_pass_through_transform(ZN1),
-                      make_pass_through_transform(ZM2),
-                      make_pass_through_transform(ZN2),
-                      make_unmerge_transform(make_tuple(ZM3 / ZN4 / ZN5, ZN4, ZN5)),
-                      make_merge_transform_v3_division_mod(make_tuple(ZN3, ZN4, ZN5))),
-           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{},
-                      Sequence<5>{}, Sequence<6>{}, Sequence<7, 8, 9>{}),
-           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{},
-                      Sequence<5>{}, Sequence<6, 7, 8>{}, Sequence<9>{}));
+       constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+           make_tuple(make_pass_through_transform(ZM0),
+                      make_pass_through_transform(ZN0),
+                      make_pass_through_transform(ZM1),
+                      make_pass_through_transform(ZN1),
+                      make_unmerge_transform(make_tuple(ZN2, ZN3, ZN4)),
+                      make_merge_transform_v3_division_mod(make_tuple(ZN2, ZN3, ZN4))),
+           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{},
+                      Sequence<5, 6, 7>{}),
+           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
+                      Sequence<4, 5, 6>{}, Sequence<7>{}));

        StaticBuffer<AddressSpaceEnum::Vgpr,
-                    ushort,
-                    z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
+                    uint8_t,
+                    z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
                     true>
            z_tensor_buffer;
        z_tensor_buffer.Clear();

        auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6.GetElementSpaceSize());
+           p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());

        auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<ushort*>(p_shared),
-           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
+           static_cast<uint8_t*>(p_shared),
+           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());

        auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
-           ushort,
-           ushort,
-           decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-           decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-           tensor_operation::element_wise::PassThrough,
-           Sequence<m0,             // MRepeat
-                    DropoutNRepeat, // NRepeat
-                    m1,             // MWaveId
-                    n1,             // NWaveId
-                    I1,
-                    DropoutStepPerXdl,
-                    m2,
-                    DropoutGroupPerTile,
-                    n3,
-                    n4>, // RegisterNum
-           Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-           9, // DstVectorDim
-           1, // DstScalarPerVector
-           InMemoryDataOperationEnum::Set,
-           1, // DstScalarStrideInVector
-           true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                 make_multi_index(0,           // MRepeat
-                                  0,           // NRepeat
-                                  wave_id[I0], // MWaveId
-                                  wave_id[I1], // NWaveId
-                                  wave_m_n_id[I1] / DropoutMThread,
-                                  0,
-                                  wave_m_n_id[I1] % DropoutMThread,
-                                  0,
-                                  wave_m_n_id[I0],
-                                  0),
-                 tensor_operation::element_wise::PassThrough{}};
-
-       auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-           ushort,
-           ushort,
-           decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
-           decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
-           Sequence<m0, DropoutNRepeat, m1, n1, I1, DropoutStepPerXdl, DropoutGroupPerTile, n3, n4, m2>,
-           Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-           9,
-           1,
-           1,
-           true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                 make_multi_index(0,           // MRepeat
-                                  0,           // NRepeat
-                                  wave_id[I0], // MWaveId
-                                  wave_id[I1], // NWaveId
-                                  wave_m_n_id[I1] / DropoutMThread,
-                                  0,
-                                  0,
-                                  wave_m_n_id[I0],
-                                  0,
-                                  wave_m_n_id[I1] % DropoutMThread)};
-
-       auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-           ushort,
-           ZDataType,
-           decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
-           decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
-           tensor_operation::element_wise::PassThrough,
-           Sequence<I1,             // MBlockId
-                    I1,             // NBlockID
-                    m0,             // MRepeat
-                    DropoutNRepeat, // NRepeat
-                    m1,             // MWaveId
-                    n1,             // NWaveId
-                    I1,
-                    DropoutStepPerXdl,
-                    m2,
-                    DropoutGroupPerTile,
-                    n3,
-                    n4>,
-           Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>,
-           11, // DstVectorDim
-           1,  // DstScalarPerVector
-           InMemoryDataOperationEnum::Set,
-           1, // DstScalarStrideInVector
-           true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
-                 make_multi_index(block_work_idx_m, // MBlockId
-                                  0,                // NBlockId
-                                  0,                // mrepeat
-                                  0,                // nrepeat
-                                  wave_id[I0],      // MWaveId
-                                  wave_id[I1],      // NWaveId
-                                  wave_m_n_id[I1] / DropoutMThread,
-                                  0,
-                                  wave_m_n_id[I1] % DropoutMThread,
-                                  0,
-                                  wave_m_n_id[I0],
-                                  0),
-                 tensor_operation::element_wise::PassThrough{}};
+           uint8_t,
+           uint8_t,
+           decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+           decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+           tensor_operation::element_wise::PassThrough,
+           Sequence<m0,          // MRepeat
+                    DropoutStep, // NRepeat
+                    m1,          // MWaveId
+                    n1,          // NWaveId
+                    m2,
+                    n2,
+                    n3,
+                    n4>, // RegisterNum
+           Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+           7, // DstVectorDim
+           1, // DstScalarPerVector
+           InMemoryDataOperationEnum::Set,
+           1, // DstScalarStrideInVector
+           true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                 make_multi_index(0,           // MRepeat
+                                  0,           // NRepeat
+                                  wave_id[I0], // MWaveId
+                                  wave_id[I1], // NWaveId
+                                  wave_m_n_id[I1],
+                                  0,
+                                  wave_m_n_id[I0],
+                                  0),
+                 tensor_operation::element_wise::PassThrough{}};
+
+       auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
+           uint8_t,
+           uint8_t,
+           decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+           decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+           Sequence<m0, DropoutStep, m1, n1, n2, n3, n4, m2>,
+           Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+           7,
+           1,
+           1,
+           true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                 make_multi_index(0,           // MRepeat
+                                  0,           // NRepeat
+                                  wave_id[I0], // MWaveId
+                                  wave_id[I1], // NWaveId
+                                  0,
+                                  wave_m_n_id[I0],
+                                  0,
+                                  wave_m_n_id[I1])};
+
+       auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
+           uint8_t,
+           ZDataType,
+           decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+           decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+           tensor_operation::element_wise::PassThrough,
+           Sequence<I1,          // MBlockId
+                    I1,          // NBlockID
+                    m0,          // MRepeat
+                    DropoutStep, // NRepeat
+                    m1,          // MWaveId
+                    n1,          // NWaveId
+                    m2,
+                    n2,
+                    n3,
+                    n4>,
+           Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+           9, // DstVectorDim
+           1, // DstScalarPerVector
+           InMemoryDataOperationEnum::Set,
+           1, // DstScalarStrideInVector
+           true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                 make_multi_index(block_work_idx_m, // MBlockId
+                                  0,                // NBlockId
+                                  0,                // mrepeat
+                                  0,                // nrepeat
+                                  wave_id[I0],      // MWaveId
+                                  wave_id[I1],      // NWaveId
+                                  wave_m_n_id[I1],
+                                  0,
+                                  wave_m_n_id[I0],
+                                  0),
+                 tensor_operation::element_wise::PassThrough{}};

        if constexpr(Deterministic)
        {
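Aside (not part of the diff): the three transfers above stage each thread's random bytes through LDS so they can be read back with two sub-dimension factors exchanged. A deliberately stripped-down model of that idea, not the CK descriptors:

#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int G = 4, B = 2; // toy stand-ins for the group/input-block factors
    uint8_t lds[G * B];
    for(int g = 0; g < G; ++g) // write in (g, b) order...
        for(int b = 0; b < B; ++b)
            lds[g * B + b] = static_cast<uint8_t>(10 * g + b);
    for(int b = 0; b < B; ++b) // ...read back in (b, g) order
        for(int g = 0; g < G; ++g)
            std::printf("%d ", lds[g * B + b]); // the two factors are transposed
    std::printf("\n");
    return 0;
}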
@@ -1321,8 +1271,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                    blockwise_softmax.Run(acc_thread_buf, workspace_buf);

-                   constexpr auto iterator_offset = Number<8 * DropoutStep>{};
-                   constexpr auto iterator_step =
-                       Number<n0 * n1 * n2 * n3 * n4 / 8 / DropoutStep>{};
+                   constexpr auto iterator_offset = Number<16 * DropoutStep>{};
+                   constexpr auto iterator_step =
+                       Number<m0 * n0 * n1 * n2 * n3 * n4 / 16 / DropoutStep>{};

                    if constexpr(IsDropout) // dropout
                    {
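Aside (not part of the diff): the iterator constants keep the invariant that offset times step covers every element a thread owns; only the per-call yield changed from 8 values to 16 bytes. A sketch with assumed sizes:

int main()
{
    constexpr int DropoutStep      = 1;  // assumed template value
    constexpr int elems_per_thread = 64; // assumed m0 * n0 * n1 * n2 * n3 * n4
    constexpr int iterator_offset  = 16 * DropoutStep;                    // was 8 * DropoutStep
    constexpr int iterator_step    = elems_per_thread / 16 / DropoutStep; // was ... / 8 / ...
    static_assert(iterator_offset * iterator_step == elems_per_thread, "full coverage");
    return 0;
}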
@@ -1343,18 +1293,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                            decltype(DropoutTile)>(ph, global_elem_id, z_tensor_buffer);

-                       z_tmp_thread_copy_vgpr_to_lds.Run(
-                           z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                           z_tensor_buffer,
-                           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                           z_block_buf);
+                       z_tmp_thread_copy_vgpr_to_lds.Run(
+                           z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                           z_tensor_buffer,
+                           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                           z_block_buf);

-                       z_shuffle_thread_copy_lds_to_vgpr.Run(
-                           z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                           z_block_buf,
-                           z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                           z_tensor_buffer);
+                       z_shuffle_thread_copy_lds_to_vgpr.Run(
+                           z_shuffle_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                           z_block_buf,
+                           z_shuffle_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                           z_tensor_buffer);

                        blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
@@ -1367,14 +1316,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                    if(p_z_grid && (gemm1_n_block_data_idx_on_grid == 0))
                    {
-                       z_thread_copy_vgpr_to_global.Run(
-                           z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
-                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                           z_tensor_buffer,
-                           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
-                           z_grid_buf);
+                       z_thread_copy_vgpr_to_global.Run(
+                           z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                           z_tensor_buffer,
+                           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                           z_grid_buf);

-                       z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
-                           make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
+                       z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                           make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                    }
                });
            }

include/ck/utility/philox_rand.hpp

@@ -84,6 +84,19 @@ class philox
        out_tmp[3] = tmp_ph.w;
    }

+   __device__ void get_random_16x8(uint8_t* out, const unsigned long long subsequence)
+   {
+       uint4 tmp_ph;
+       tmp_ph = get_philox_4x32(subsequence);
+
+       uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+
+       out_tmp[0] = tmp_ph.x;
+       out_tmp[1] = tmp_ph.y;
+       out_tmp[2] = tmp_ph.z;
+       out_tmp[3] = tmp_ph.w;
+   }
+
    __device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
    {
        uint4 tmp_ph;
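Aside (not part of the diff): get_random_16x8() simply stores the four 32-bit words of one Philox draw through a uint32_t view of the 16-byte destination. A host-side re-creation for illustration (uint4 modeled by a struct, memcpy standing in for the device-side reinterpret_cast):

#include <cstdint>
#include <cstring>

struct uint4_model { uint32_t x, y, z, w; };

void pack_16x8(uint8_t* out, const uint4_model& tmp_ph)
{
    std::memcpy(out + 0,  &tmp_ph.x, 4);
    std::memcpy(out + 4,  &tmp_ph.y, 4);
    std::memcpy(out + 8,  &tmp_ph.z, 4);
    std::memcpy(out + 12, &tmp_ph.w, 4);
}

int main()
{
    uint4_model draw{0x03020100u, 0x07060504u, 0x0b0a0908u, 0x0f0e0d0cu};
    uint8_t bytes[16];
    pack_16x8(bytes, draw);
    // On a little-endian target, bytes[] now holds 0x00, 0x01, ..., 0x0f.
    return bytes[0];
}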

library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp

@@ -25,19 +25,19 @@ struct ReferenceDropout : public device::BaseOperator
    Argument(const Tensor<RefDataType>& ref,
             const Tensor<InDataType>& in,
             Tensor<OutDataType>& out,
-            RefDataType p_dropout_in_16bits,
+            RefDataType p_dropout_in_uint8_t,
             float rp_dropout)
        : ref_(ref),
          in_(in),
          out_(out),
-         p_dropout_in_16bits_(p_dropout_in_16bits),
+         p_dropout_in_uint8_t_(p_dropout_in_uint8_t),
          rp_dropout_(rp_dropout)
    {
    }

    const Tensor<RefDataType>& ref_;
    const Tensor<InDataType>& in_;
    Tensor<OutDataType>& out_;
-   RefDataType p_dropout_in_16bits_;
+   RefDataType p_dropout_in_uint8_t_;
    float rp_dropout_;
};
@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
        {
            arg.out_.ForEach([&](auto& self, auto idx) {
-               self(idx) = arg.ref_(idx) <= arg.p_dropout_in_16bits_
+               self(idx) = arg.ref_(idx) <= arg.p_dropout_in_uint8_t_
                                ? ck::type_convert<OutDataType>(
                                      ck::type_convert<float>(arg.in_(idx)) *
                                      ck::type_convert<float>(arg.rp_dropout_))
                                : 0;
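Aside (not part of the diff): worked numbers for the comparison above, with p_drop = 0.5 so p_dropout_in_uint8_t = floor(0.5 * 255) = 127; 128 of the 256 possible reference bytes survive and are rescaled by 1 / (1 - 0.5) = 2:

#include <cassert>
#include <cstdint>

int main()
{
    const uint8_t threshold  = 127;  // p_dropout_in_uint8_t for p_drop = 0.5
    const float   rp_dropout = 2.0f;

    const uint8_t ref[4] = {0, 127, 128, 255};      // random reference bytes
    const float   in[4]  = {1.0f, 2.0f, 3.0f, 4.0f};
    float out[4];
    for(int i = 0; i < 4; ++i)
        out[i] = (ref[i] <= threshold) ? in[i] * rp_dropout : 0.0f;

    assert(out[0] == 2.0f && out[1] == 4.0f && out[2] == 0.0f && out[3] == 0.0f);
    return 0;
}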
@@ -74,10 +74,10 @@ struct ReferenceDropout : public device::BaseOperator
    static auto MakeArgument(const Tensor<RefDataType>& ref,
                             const Tensor<InDataType>& in,
                             Tensor<OutDataType>& out,
-                            RefDataType p_dropout_in_16bits,
+                            RefDataType p_dropout_in_uint8_t,
                             float rp_dropout)
    {
-       return Argument{ref, in, out, p_dropout_in_16bits, rp_dropout};
+       return Argument{ref, in, out, p_dropout_in_uint8_t, rp_dropout};
    }

    static auto MakeInvoker() { return Invoker{}; }