gaoqiong / composable_kernel / Commits / c798cff9

Commit c798cff9 authored Dec 20, 2022 by Anthony Chang

    refactor gemm0

parent db7f7bed
Showing 2 changed files with 263 additions and 325 deletions (+263, -325):

  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp    (+4, -1)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+259, -324)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp @ c798cff9
...
...
@@ -641,7 +641,10 @@ int run(int argc, char* argv[])
         std::cout << "Checking qgrad:\n";
         pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
-                                     qgrad_gs_ms_ks_host_result.mData);
+                                     qgrad_gs_ms_ks_host_result.mData,
+                                     "error",
+                                     1e-2,
+                                     1e-2);
         std::cout << "Checking kgrad:\n";
         pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
                                      kgrad_gs_ns_ks_host_result.mData);
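The only functional change in the example is that the qgrad comparison now passes explicit arguments after the two buffers: a message string and what appear to be relative and absolute tolerances of 1e-2, loosened to absorb fp16 accumulation error in the backward pass. A minimal sketch of the same call pattern (buffer names hypothetical):

    // Sketch only: device-vs-host comparison with loosened 1e-2 tolerances,
    // mirroring the qgrad check in the hunk above.
    bool pass = ck::utils::check_err(
        dev_result.mData, host_ref.mData, "error", 1e-2, 1e-2);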
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @ c798cff9
...
...
@@ -446,26 +446,124 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
-    struct PGradGemmTile_M_N_O
+    struct SharedMemTrait
     {
-        private:
-        static constexpr auto ygrad_block_desc_o0_m_o1 =
-        static constexpr auto v_block_desc_o0_n_o1 =
+        // LDS allocation for A and B: be careful of alignment
+        static constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+        static constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        static constexpr auto b1_block_desc_bk0_n_bk1 =
+            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        static constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
+        static constexpr auto ygrad_block_desc_m0_o_m1 =
+            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
+
+        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
+
+        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
+            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
+            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
+            p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
+            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
+
+        static constexpr auto a_block_space_offset     = 0;
+        static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
+        static constexpr auto b1_block_space_offset    = 0;
+        static constexpr auto p_block_space_offset     = 0;
+        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
+
+        // LDS allocation for reduction
+        static constexpr index_t reduction_space_size_aligned =
+            math::integer_least_multiple(BlockSize, max_lds_align);
+        static constexpr auto reduction_space_offset = 0;
+
+        // LDS allocation for C shuffle in LDS
+        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
+            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+        static constexpr auto c_block_space_size =
+            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
+    };
+
+    // P / dP Gemm (type 1 rcr)
+    struct Gemm0
+    {
+        private:
+        static constexpr auto a_block_desc = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+        static constexpr auto b_block_desc = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+
+        public:
+        template <typename GridDesc_K0_M_K1>
+        using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock, AElementwiseOperation, tensor_operation::element_wise::PassThrough,
+            InMemoryDataOperationEnum::Set, Sequence<AK0, MPerBlock, AK1>,
+            ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder,
+            DataType, DataType, GridDesc_K0_M_K1, decltype(a_block_desc),
+            ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, ABlockTransferSrcVectorDim, 2,
+            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, 1, 1,
+            true, // SrcResetCoord
+            true, // DstResetCoord
+            NumGemmKPrefetchStage>;
+
+        template <typename GridDesc_K0_N_K1>
+        using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
+            InMemoryDataOperationEnum::Set, Sequence<BK0, NPerBlock, BK1>,
+            BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder,
+            DataType, DataType, GridDesc_K0_N_K1, decltype(b_block_desc),
+            BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, 2,
+            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, 1, 1,
+            true, // SrcResetCoord
+            true, // DstResetCoord
+            NumGemmKPrefetchStage>;
+
+        static constexpr index_t KPack = math::max(
+            math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+
         public:
-        using BlockwiseGemm = BlockwiseGemmXdlops_v2<BlockSize,
+        // Blockwise gemm with transposed XDL output
+        using BlockwiseGemm = BlockwiseGemmXdlops_v2<BlockSize,
                                                      DataType,
                                                      FloatGemmAcc,
-                                                     decltype(ygrad_block_desc_o0_m_o1),
-                                                     decltype(v_block_desc_o0_n_o1),
-                                                     decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(ygrad_block_desc_o0_m_o1)),
-                                                     decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(v_block_desc_o0_n_o1)),
+                                                     decltype(a_block_desc),
+                                                     decltype(b_block_desc),
+                                                     decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc)),
+                                                     decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc)),
                                                      MPerBlock,
                                                      NPerBlock,
                                                      KPerBlock,
...
...
@@ -474,8 +572,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                      MXdlPerWave,
                                                      NXdlPerWave,
                                                      KPack,
-                                                     true>;
+                                                     true>; // TransposeC
+
+        static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
+        static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+    };
+
+    // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
+    struct PGradGemmTile_M_N_O
+    {
         // TODO ANT:
         // Should have made all input tensors 2D and transform them into appropriate 3D form in
         // kernel to make things more concise - if we can get the compiler to behave
         template <typename YGradGridDesc_M0_O_M1_>
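Taken together, the two hunks above move the P / dP gemm's building blocks into the SharedMemTrait and Gemm0 trait structs: the LDS block descriptors and offsets, the A/B ThreadGroupTensorSliceTransfer_v4r1 copy types, KPack, the BlockwiseGemm type, and the slice-copy steps. A condensed sketch of how later hunks in this same commit consume them (constructor argument lists trimmed; see the actual hunks below for the full calls):

    // Sketch mirroring the instantiations further down in this diff.
    auto a_blockwise_copy =
        typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
            q_grid_desc_k0_m_k1,
            make_multi_index(0, m_block_data_idx_on_grid, 0),
            a_element_op /* , dst block descriptor, dst origin, ... */);

    auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{};

    constexpr auto a_block_slice_copy_step = Gemm0::a_block_slice_copy_step;
    constexpr auto b_block_slice_copy_step = Gemm0::b_block_slice_copy_step;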
...
...
@@ -597,52 +703,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
     };

-    struct SharedMemTrait
-    {
-        // LDS allocation for A and B: be careful of alignment
-        static constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        static constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto b1_block_desc_bk0_n_bk1 =
-            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        static constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
-        static constexpr auto ygrad_block_desc_m0_o_m1 =
-            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
-
-        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
-
-        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
-            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
-            p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
-            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
-
-        static constexpr auto a_block_space_offset     = 0;
-        static constexpr auto b_block_space_offset     = a_block_space_size_aligned.value;
-        static constexpr auto b1_block_space_offset    = 0;
-        static constexpr auto p_block_space_offset     = 0;
-        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
-
-        // LDS allocation for reduction
-        static constexpr index_t reduction_space_size_aligned =
-            math::integer_least_multiple(BlockSize, max_lds_align);
-        static constexpr auto reduction_space_offset = 0;
-
-        // LDS allocation for C shuffle in LDS
-        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-        static constexpr auto c_block_space_size =
-            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-    };
-
     template <bool HasMainKBlockLoop,
               typename Block2CTileMap,
               typename C0MatrixMask,
...
...
@@ -703,47 +763,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return;
         }

-        // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
+        // HACK: this force m/o_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

-        const index_t gemm1_n_block_data_idx_on_grid =
+        const index_t o_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);

-        //
-        // set up P / dP Gemm (type 1 rcr)
-        //
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
         // B matrix in LDS memory, dst of blockwise copy
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

+        //
+        // set up Gemm0
+        //
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<
-                ThisThreadBlock, AElementwiseOperation, tensor_operation::element_wise::PassThrough,
-                InMemoryDataOperationEnum::Set, Sequence<AK0, MPerBlock, AK1>,
-                ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder,
-                DataType, DataType, decltype(q_grid_desc_k0_m_k1), decltype(a_block_desc_ak0_m_ak1),
-                ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, ABlockTransferSrcVectorDim, 2,
-                ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, 1, 1,
-                true, // SrcResetCoord
-                true, // DstResetCoord
-                NumGemmKPrefetchStage>(
+            typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
                 q_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
...
@@ -753,28 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<
-                ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
-                InMemoryDataOperationEnum::Set, Sequence<BK0, NPerBlock, BK1>,
-                BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder,
-                DataType, DataType, decltype(k_grid_desc_k0_n_k1), decltype(b_block_desc_bk0_n_bk1),
-                BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, 2,
-                BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, 1, 1,
-                true, // SrcResetCoord
-                true, // DstResetCoord
-                NumGemmKPrefetchStage>(
+            typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
                 k_grid_desc_k0_n_k1,
                 make_multi_index(0, 0, 0), // will loop over GemmN dimension
                 b_element_op,
...
...
@@ -782,35 +800,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});

         // Fused Gemm+Gemm pipeline
         // for n in N0:
         //   for k in K0:
         //     acc[m][n] += A[m][k] * B0[k][n]
         //   acc1[m][o] += acc[m][n] * B1[n][o]
+        auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC

-        // sanity check
-        constexpr index_t KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
-
-        auto blockwise_gemm =
-            BlockwiseGemmXdlops_v2<BlockSize,
-                                   DataType,
-                                   FloatGemmAcc,
-                                   decltype(a_block_desc_ak0_m_ak1),
-                                   decltype(b_block_desc_bk0_n_bk1),
-                                   decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
-                                   decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
-                                   MPerBlock,
-                                   NPerBlock,
-                                   KPerBlock,
-                                   MPerXdl,
-                                   NPerXdl,
-                                   MXdlPerWave,
-                                   NXdlPerWave,
-                                   KPack,
-                                   true>{}; // TransposeC
-
-        auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
+        auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();

         // LDS allocation for A and B: be careful of alignment
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
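The "Fused Gemm+Gemm pipeline" comment kept in this hunk is the clearest statement of what the kernel computes per block tile. As a host-side illustration of that comment only (my sketch: scalar, row-major, and omitting the softmax/masking the kernel interleaves between the two gemms):

    // Reference semantics of the fused Gemm+Gemm loop described in the comment:
    //   acc[m][n]  = sum_k A[m][k] * B0[k][n]
    //   acc1[m][o] += acc[m][n] * B1[n][o]
    void fused_gemm_gemm_reference(const float* A, const float* B0, const float* B1,
                                   float* acc1, int M, int N, int K, int O)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += A[m * K + k] * B0[k * N + n]; // Gemm0
                for(int o = 0; o < O; ++o)
                    acc1[m * O + o] += acc * B1[n * O + o]; // Gemm1 consumes Gemm0 output
            }
    }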
...
...
@@ -821,8 +813,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+        constexpr auto a_block_slice_copy_step = Gemm0::a_block_slice_copy_step;
+        constexpr auto b_block_slice_copy_step = Gemm0::b_block_slice_copy_step;

         const auto a_block_reset_copy_step =
             make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
         const auto b_block_reset_copy_step =
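As the context lines above show, the per-block LDS buffers are carved out of the single p_shared allocation using the SharedMemTrait offsets defined earlier in this commit (a_block_space_offset is 0 and b_block_space_offset equals a_block_space_size_aligned.value). A minimal sketch of that pattern; the a_block_buf line is my inference by symmetry with the b_block_buf line shown above:

    // Sketch: A and B tiles share one LDS allocation, addressed via SharedMemTrait offsets.
    auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset, // assumed
        a_block_desc_ak0_m_ak1.GetElementSpaceSize());
    auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
        b_block_desc_bk0_n_bk1.GetElementSpaceSize());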
...
...
@@ -839,12 +832,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                              KPerBlock);

         //
-        // set up O / dQ Gemm
+        // set up Y / dQ Gemm (type 2 rrr)
         //

         // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
         constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

         constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
         constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
...
...
@@ -873,7 +866,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // A1 matrix in AccVGPR
         // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
         constexpr auto AccN3 =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
+            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);

         constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(
             Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
...
...
@@ -889,47 +882,47 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

         // A1 matrix blockwise copy
-        // auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
-        //     FloatGemmAcc, DataType,
-        //     decltype(acc_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1),
-        //     tensor_operation::element_wise::PassThrough,
-        //     Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
-        //     Sequence<1, 0, 2>, 2, n4>{tensor_operation::element_wise::PassThrough{}};
+        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+            FloatGemmAcc, DataType,
+            decltype(acc_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
+            Sequence<1, 0, 2>, 2, n4>{tensor_operation::element_wise::PassThrough{}};

         // B1 matrix blockwise copy
-        // auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
-        //     ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
-        //     InMemoryDataOperationEnum::Set, Sequence<B1K0, Gemm1NPerBlock, B1K1>,
-        //     B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder,
-        //     DataType, DataType, decltype(v_grid_desc_n0_o_n1), decltype(b1_block_desc_bk0_n_bk1),
-        //     B1BlockTransferSrcAccessOrder, Sequence<1, 0, 2>, B1BlockTransferSrcVectorDim, 2,
-        //     B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, 1, 1,
-        //     B1ThreadTransferSrcResetCoordinateAfterRun,
-        //     true, // DstResetCoord
-        //     NumGemmKPrefetchStage>(
-        //         v_grid_desc_n0_o_n1,
-        //         make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
-        //         b1_element_op,
-        //         b1_block_desc_bk0_n_bk1,
-        //         make_multi_index(0, 0, 0),
-        //         tensor_operation::element_wise::PassThrough{});
+        auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock, BElementwiseOperation, tensor_operation::element_wise::PassThrough,
+            InMemoryDataOperationEnum::Set, Sequence<B1K0, Gemm1NPerBlock, B1K1>,
+            B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder,
+            DataType, DataType, decltype(v_grid_desc_n0_o_n1), decltype(b1_block_desc_bk0_n_bk1),
+            B1BlockTransferSrcAccessOrder, Sequence<1, 0, 2>, B1BlockTransferSrcVectorDim, 2,
+            B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, 1, 1,
+            B1ThreadTransferSrcResetCoordinateAfterRun,
+            true, // DstResetCoord
+            NumGemmKPrefetchStage>(
+                v_grid_desc_n0_o_n1,
+                make_multi_index(0, o_block_data_idx_on_grid, 0),
+                b1_element_op,
+                b1_block_desc_bk0_n_bk1,
+                make_multi_index(0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});

         auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
             a1_thread_desc_k0_m_k1.GetElementSpaceSize());
...
...
@@ -984,8 +977,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // get acc0 8D thread cluster
         constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
+            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();

         constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
         constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
         constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
...
...
@@ -1021,23 +1014,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                            decltype(thread_cluster_desc_m_n),
                                            decltype(thread_slice_desc_m_n)>{};

-        const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
-        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
-
-        // Initialize C
-        StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, acc1_thread_buf.Size(), true>
-            c_thread_buf;
-        c_thread_buf.Clear();
-
-        // Initialize running sum and max of exponentiating row vectors
-        using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
-        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
-        running_sum     = 0;
-        running_sum_new = 0;
-        running_max     = NumericLimits<FloatGemmAcc>::Lowest();
-        running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
-
         auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
             MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
...
...
@@ -1047,7 +1023,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
             lse_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());

-        auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+        auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
             Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});

         auto lse_thread_copy_global_to_vgpr =
...
...
@@ -1068,19 +1044,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                              acc0_thread_origin[I4])}; // mperxdl

         //
-        // dV
+        // set up dV / dK Gemm (type 3 crr)
         //

         // P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1
         // m0, n0 are m/n repeat per wave
         // m1, n1 are number of waves
         constexpr auto p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

         constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();

         constexpr auto p_block_lengths =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();

         constexpr auto P_M0 = p_block_lengths[I0]; // repeats
         constexpr auto P_N0 = p_block_lengths[I1];
         constexpr auto P_M1 = p_block_lengths[I2]; // waves
...
...
@@ -1113,7 +1089,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto p_thread_origin_nd_idx_on_block = [&]() {
             const auto c_thread_mtx_on_block =
-                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+                s_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

             const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
             const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
...
...
@@ -1184,17 +1160,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                  p_thread_origin_nd_idx_on_block[I7]),
                                  tensor_operation::element_wise::PassThrough{}};

-        // Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
-        //          p_block_slice_lengths_m0_n0_m1_n1[I1],
-        //          I1,
-        //          I1,
-        //          I1,
-        //          P_N2,
-        //          I1,
-        //          P_N4>{}
-        //     .foo();
-        // 1, 4, 1, 1, 1, 4, 1, 4
-
         constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
             SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
                               Sequence<0, 1, 2, 3>,
...
...
@@ -1229,7 +1194,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                     true,
                                                     1>(
             ygrad_grid_desc_m0_o_m1,
             make_multi_index(m_block_data_idx_on_grid / VGradGemmTile_N_O_M::YGrad_M1,
-                             gemm1_n_block_data_idx_on_grid,
+                             o_block_data_idx_on_grid,
                              0),
             tensor_operation::element_wise::PassThrough{},
             ygrad_block_desc_m0_o_m1,
...
...
@@ -1292,7 +1257,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const index_t n_thread_data_idx_on_grid = vgrad_thread_mtx_on_block_n_o[I0];

         const index_t o_thread_data_idx_on_grid =
-            vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;
+            vgrad_thread_mtx_on_block_n_o[I1] + o_block_data_idx_on_grid;

         const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
...
...
@@ -1375,8 +1340,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 #endif

         //
-        // dP
+        // set up Y dot dY
         //
         constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(
             make_tuple(I1,
                        YDotYGrad_M_O::ThreadSliceLength_M,
                        I1,
                        YDotYGrad_M_O::ThreadSliceLength_O));

         constexpr auto y_thread_cluster_desc =
...
...
@@ -1424,35 +1390,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // per-thread LSE data and y_dot_ygrad is
         // tiled the same way
+        auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
+            FloatGemmAcc,
+            FloatGemmAcc,
+            decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
+            decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
+            Sequence<1, m0, m1, m2>,
+            Sequence<0, 1, 2, 3>,
+            3,
+            m2,
+            1,
+            false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
+                   make_multi_index(I0,                       // mblock
+                                    acc0_thread_origin[I0],   // mrepeat
+                                    acc0_thread_origin[I2],   // mwave
+                                    acc0_thread_origin[I4])}; // mperxdl
+
+        //
+        // set up dP Gemm (type 1 rcr)
+        //
         // transform input and output tensor descriptors
         const auto ygrad_grid_desc_o0_m_o1 =
             PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
         const auto v_grid_desc_o0_n_o1 =
             PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);

-        // dP Gemm A matrix blockwise copy
-        auto pgrad_gemm_tile_ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
-            ThisThreadBlock, tensor_operation::element_wise::PassThrough,
-            tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum::Set,
-            Sequence<AK0, MPerBlock, AK1>, ABlockTransferThreadClusterLengths_AK0_M_AK1,
-            ABlockTransferThreadClusterArrangeOrder, DataType, DataType,
-            decltype(ygrad_grid_desc_o0_m_o1), decltype(a_block_desc_ak0_m_ak1), // reuse block buf
-            ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, ABlockTransferSrcVectorDim, 2,
-            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, 1, 1,
-            true, // SrcResetCoord
-            true, // DstResetCoord
-            NumGemmKPrefetchStage>(
+        // dP Gemm A position blockwise copy
+        auto pgrad_gemm_tile_ygrad_blockwise_copy =
+            typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
                 ygrad_grid_desc_o0_m_o1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 tensor_operation::element_wise::PassThrough{},
...
...
@@ -1460,30 +1426,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});

-        // dP Gemm B matrix blockwise copy
-        auto pgrad_gemm_tile_v_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
-            ThisThreadBlock, tensor_operation::element_wise::PassThrough,
-            tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum::Set,
-            Sequence<BK0, NPerBlock, BK1>, BBlockTransferThreadClusterLengths_BK0_N_BK1,
-            BBlockTransferThreadClusterArrangeOrder, DataType, DataType,
-            decltype(v_grid_desc_o0_n_o1), decltype(b_block_desc_bk0_n_bk1), // reuse block buf
-            BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, 2,
-            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, 1, 1,
-            true, // SrcResetCoord
-            true, // DstResetCoord
-            NumGemmKPrefetchStage>(
+        // dP Gemm B position blockwise copy
+        auto pgrad_gemm_tile_v_blockwise_copy =
+            typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
                 v_grid_desc_o0_n_o1,
                 make_multi_index(0, 0, 0), // will loop over GemmN dimension
                 tensor_operation::element_wise::PassThrough{},
...
...
@@ -1491,7 +1436,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});

-        auto pgrad_blockwise_gemm = typename PGradGemmTile_M_N_O::BlockwiseGemm{};
+        auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};

         auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();

         const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
             make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
...
...
@@ -1502,25 +1447,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
             KPerBlock);

-        auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-            FloatGemmAcc,
-            FloatGemmAcc,
-            decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
-            decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
-            Sequence<1, m0, m1, m2>,
-            Sequence<0, 1, 2, 3>,
-            3,
-            m2,
-            1,
-            false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
-                   make_multi_index(I0,                       // mblock
-                                    acc0_thread_origin[I0],   // mrepeat
-                                    acc0_thread_origin[I2],   // mwave
-                                    acc0_thread_origin[I4])}; // mperxdl
-
-        // clear accum buffers
-        y_dot_ygrad_thread_accum_buf.Clear();
-        y_dot_ygrad_block_accum_buf.Clear();
-
 #if 0
         if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds before accum\n");
         if(hipBlockIdx_x == 0)
...
...
@@ -1533,7 +1459,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());

         //
-        // dQ
+        // set up dQ Gemm (type 2 rrr)
         //
         const auto k_grid_desc_n0_k_n1 =
             QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
...
...
@@ -1578,7 +1504,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                 true, // DstResetCoord
                                                 NumGemmKPrefetchStage>(
                 k_grid_desc_n0_k_n1,
-                make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+                make_multi_index(0, o_block_data_idx_on_grid, 0),
                 b1_element_op,
                 b1_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
...
...
@@ -1607,9 +1533,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                             make_tuple(0, 0, 0, 0)}; // A_origin

         auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();

+        //
+        // calculate y dot ygrad
+        //
+        // clear accum buffers
+        y_dot_ygrad_thread_accum_buf.Clear();
+        y_dot_ygrad_block_accum_buf.Clear();
+
         index_t oblock_idx = 0;
         do
         {
...
...
@@ -1673,7 +1605,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
 #endif

-        // distribute y_dot_ygrad to threads; LDS accum buffer can be safely accessed after barrier
+        // distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
         y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
             y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
             y_dot_ygrad_block_accum_buf,
...
...
@@ -1681,7 +1613,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             make_tuple(I0, I0, I0, I0),
             y_dot_ygrad_thread_buf);

-#if 1
+#if 0
         if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
         {
             printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
...
...
@@ -1700,6 +1632,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                           make_tuple(I0, I0, I0, I0),
                                           lse_thread_buf);

+        const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
+        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
+
         // Initialize dQ
         qgrad_thread_buf.Clear();
...
...
@@ -1715,7 +1650,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             {
                 continue;
             }

-            // gemm0
+            // P = Q * K^T
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
                                                                    a_block_desc_ak0_m_ak1,
                                                                    a_blockwise_copy,
...
...
@@ -1728,8 +1663,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                                    k_grid_buf,
                                                                    b_block_buf,
                                                                    b_block_slice_copy_step,
-                                                                   blockwise_gemm,
-                                                                   acc_thread_buf,
+                                                                   s_blockwise_gemm,
+                                                                   s_slash_p_thread_buf,
                                                                    num_k_block_main_loop);

             // do MNK padding or upper triangular masking
...
...
@@ -1737,11 +1672,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             {
                 // 8d thread_desc in thread scope
                 constexpr auto c_thread_lengths =
-                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                    s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();

                 // 8d block_desc in block scope
                 constexpr auto c_block_lengths =
-                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                    s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();

                 constexpr auto M0 = c_block_lengths[I0];
                 constexpr auto N0 = c_block_lengths[I1];
...
...
@@ -1760,7 +1695,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
                     false>; // SnakeCurved

-                auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+                auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
                     Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});

                 constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
...
...
@@ -1779,11 +1714,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     auto n_global = n_local + n_block_data_idx_on_grid;
                     if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
                     {
-                        acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
+                        s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
                     }
                     else
                     {
-                        acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
+                        acc_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
                     }
                 });
             }
...
...
@@ -1800,32 +1735,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             {
                 printf("tid %zd, S[0:3] = %f, %f, %f, %f\n",
                        hipThreadIdx_x,
-                       acc_thread_buf[I0],
-                       acc_thread_buf[I1],
-                       acc_thread_buf[I2],
-                       acc_thread_buf[I3]);
+                       s_slash_p_thread_buf[I0],
+                       s_slash_p_thread_buf[I1],
+                       s_slash_p_thread_buf[I2],
+                       s_slash_p_thread_buf[I3]);
             }
 #endif

             // P_i: = softmax(S_i:)
-            blockwise_softmax.RunWithPreCalcStats(acc_thread_buf, lse_thread_buf);
+            blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

 #if 0
             if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
             {
                 printf("tid %zd, P[0:3] = %f, %f, %f, %f\n",
                        hipThreadIdx_x,
-                       acc_thread_buf[I0],
-                       acc_thread_buf[I1],
-                       acc_thread_buf[I2],
-                       acc_thread_buf[I3]);
+                       s_slash_p_thread_buf[I0],
+                       s_slash_p_thread_buf[I1],
+                       s_slash_p_thread_buf[I2],
+                       s_slash_p_thread_buf[I3]);
             }
 #endif

             block_sync_lds(); // wait for gemm1 LDS read

-            SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
-                                                             blockwise_gemm.GetWaveIdx()[I1]);
+            SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
+                                                             s_blockwise_gemm.GetWaveIdx()[I1]);

             constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
             static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
...
...
@@ -1864,7 +1799,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 p_thread_copy_vgpr_to_lds.Run(
                     p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                     make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
-                    acc_thread_buf,
+                    s_slash_p_thread_buf,
                     p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                     p_block_buf);
             }
...
...
@@ -1936,7 +1871,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                                   pgrad_blockwise_gemm,
                                                                   pgrad_thread_buf,
                                                                   num_o_block_main_loop);
-#if 1
+#if 0
             if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
             {
                 printf("outer j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
...
...
@@ -1963,10 +1898,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I1];
                 // dS and P has same thread buf layout
                 sgrad_thread_buf(i) =
-                    acc_thread_buf[i] *
+                    s_slash_p_thread_buf[i] *
                     (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
             });

-#if 1
+#if 0
             if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
             {
                 printf("outer j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
...
...
@@ -2016,7 +1951,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                        a1_thread_desc_k0_m_k1,
                                        make_tuple(I0, I0, I0),
                                        a1_thread_buf);
-#if 1
+#if 0
                 if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
                 {
                     printf("inner j loop idx %d, tid %zd, dS downcast[0:3] = %f, %f, %f, %f\n",
...
...
@@ -2079,7 +2014,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // TODO ANT:
         // shuffle dQ and write
-#if 1
+#if 0
         if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
         {
             printf("tid %zd, dQ[0:3] = %f, %f, %f, %f\n",
...
...