Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3b58c3ec
Commit
3b58c3ec
authored
Feb 07, 2023
by
guangzlu
Browse files
saved dropout random number in gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
parent
cc974f0f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
167 additions
and
2 deletions
+167
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
+167
-2
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
View file @
3b58c3ec
...
...
@@ -35,6 +35,7 @@ template <typename FloatAB,
typename
BGridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
ZGridDesc_M_N
,
typename
LSEGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -116,6 +117,65 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
// C desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
////=> for z use
{
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
const
index_t
M
,
const
index_t
N
)
////=> for z use
{
constexpr
auto
mfma
=
MfmaSelector
<
DataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
__device__
static
auto
GetGemm0WaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
Gemm0MWaves
,
Gemm0NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
GetGemm0WaveMNIdx
(
const
index_t
thread_id
)
{
constexpr
auto
wave_threadid_to_mn_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
WaveSize
/
MPerXdl
,
MPerXdl
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
wave_threadid_to_mn_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
...
...
@@ -319,10 +379,13 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
ZGridDesc_M_N
{}))
>
;
struct
SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
...
...
@@ -367,6 +430,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
unsigned
short
*
__restrict__
p_z_grid
,
FloatLSE
*
__restrict__
p_lse_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
...
...
@@ -379,6 +443,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
...
...
@@ -782,6 +848,80 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
///////////////////=>z for dropout
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
));
// registerNum
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
n4
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
///////////////////=>z for dropout
do
{
auto
n_block_data_idx_on_grid
=
...
...
@@ -876,9 +1016,34 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
if
constexpr
(
IsDropout
)
// dropout
{
blockwise_dropout
.
ApplyDropout
(
acc_thread_buf
,
ph
);
// save z to global
if
(
p_z_grid
)
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
true
>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
}
else
{
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
true
>(
acc_thread_buf
,
ph
);
}
}
//if constexpr(IsDropout) // dropout
//{
// blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
//}
// TODO: may convert to log domain
running_max_new
=
mathext
::
max
(
max
,
running_max
);
running_sum_new
=
mathext
::
exp
(
running_max
-
running_max_new
)
*
running_sum
+
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment