gaoqiong / composable_kernel
Commit 971220d8, authored Sep 16, 2021 by ltqin

gridwise gemm data copy and blockwise gemm

Parent: a52e5a92
Showing 1 changed file with 447 additions and 435 deletions

composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp (+447, −435)
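The hunks below thread a group (G) dimension through the A, B, and C grid descriptors (a_g_k0_m_k1_grid_desc, b_g_k0_n_k1_grid_desc, c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc). As a standalone illustration of what that extra leading dimension means for addressing, here is a minimal sketch assuming a packed row-major [G, K0, M, K1] layout and hypothetical names (not repository code): the group index only adds a g * (K0 * M * K1) term to every offset, so a block's K0 x MPerBlock x K1 window is addressed exactly as before.

// Hypothetical illustration, not composable_kernel code: linear addressing for a
// batched A tensor stored packed and row-major as [G, K0, M, K1].
#include <cstdio>

long long offset_g_k0_m_k1(long long g, long long k0, long long m, long long k1,
                           long long K0, long long M, long long K1)
{
    // The leading G dimension only prepends a g * (K0 * M * K1) term.
    return ((g * K0 + k0) * M + m) * K1 + k1;
}

int main()
{
    const long long K0 = 8, M = 256, K1 = 8;

    // A block that works on rows starting at m = 128 reads the same window whether
    // or not a group dimension exists; only the base offset shifts with g.
    const long long base_g0 = offset_g_k0_m_k1(0, 0, 128, 0, K0, M, K1);
    const long long base_g2 = offset_g_k0_m_k1(2, 0, 128, 0, K0, M, K1);

    std::printf("base offset for (g=0, m=128): %lld\n", base_g0);
    std::printf("base offset for (g=2, m=128): %lld\n", base_g2);
    std::printf("difference: %lld (= g * K0 * M * K1)\n", base_g2 - base_g0);
    return 0;
}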
@@ -264,23 +264,24 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
         decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));
     using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{}));

     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                FloatAB* __restrict__ p_shared_block,
-                               const AGK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                               const BGK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                               const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                               const AGK0MK1GridDesc& a_g_k0_m_k1_grid_desc,
+                               const BGK0NK1GridDesc& b_g_k0_n_k1_grid_desc,
+                               const CM0N0M1N1M2M3M4N2GridDesc& c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                const CBlockClusterAdaptor& c_block_cluster_adaptor)
     {
-        /*
-        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
-        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
-        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_a_grid, a_g_k0_m_k1_grid_desc.GetElementSpaceSize());
+        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_b_grid, b_g_k0_n_k1_grid_desc.GetElementSpaceSize());
+        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_c_grid, c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
+        const auto K0 = a_g_k0_m_k1_grid_desc.GetLength(I1);

         // divide block work by [M, N]
         const auto block_work_idx =
@@ -288,10 +289,11 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

         const index_t n_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

+        const index_t g_idx = block_work_idx[I0];
+
         // lds max alignment
         constexpr auto max_lds_align = K1;
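In the rewritten index setup above, position I0 of block_work_idx now carries the group index, and the M and N block coordinates move to I1 and I2. The actual mapping comes from c_block_cluster_adaptor, which is not shown in this excerpt, so the sketch below merely assumes a plain row-major [G, MBlocks, NBlocks] ordering of workgroups (a hypothetical choice) to show how one flat block id can decompose into (g, m_block, n_block) and how the data origins derive from it.

// Hypothetical sketch of a [G, M, N] block-work decomposition; the kernel's real
// mapping is produced by c_block_cluster_adaptor.
#include <array>
#include <cstdio>

std::array<int, 3> decompose_block_id(int block_id, int m_blocks, int n_blocks)
{
    const int n_idx = block_id % n_blocks;
    const int m_idx = (block_id / n_blocks) % m_blocks;
    const int g_idx = block_id / (n_blocks * m_blocks);
    return {g_idx, m_idx, n_idx}; // [I0] = g, [I1] = m block, [I2] = n block
}

int main()
{
    const int MPerBlock = 128, NPerBlock = 128;
    const int m_blocks = 8, n_blocks = 8;

    const auto idx = decompose_block_id(/*block_id=*/70, m_blocks, n_blocks);
    const int m_block_data_idx_on_grid = idx[1] * MPerBlock;
    const int n_block_data_idx_on_grid = idx[2] * NPerBlock;

    std::printf("g=%d, m_origin=%d, n_origin=%d\n",
                idx[0], m_block_data_idx_on_grid, n_block_data_idx_on_grid);
    return 0;
}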
@@ -301,60 +303,68 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

+       constexpr auto a_g_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
+           make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+
        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

+       constexpr auto b_g_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
+           make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+
        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                            InMemoryDataOperationEnum_t::Set,
-                                           Sequence<KPerBlock, MPerBlock, K1>,
-                                           ABlockTransferThreadSliceLengths_K0_M_K1,
-                                           ABlockTransferThreadClusterLengths_K0_M_K1,
+                                           Sequence<1, KPerBlock, MPerBlock, K1>,
+                                           ABlockTransferThreadSliceLengths_G_K0_M_K1,
+                                           ABlockTransferThreadClusterLengths_G_K0_M_K1,
                                            ABlockTransferThreadClusterArrangeOrder,
                                            FloatAB,
                                            FloatAB,
-                                           decltype(a_k0_m_k1_grid_desc),
-                                           decltype(a_k0_m_k1_block_desc),
+                                           decltype(a_g_k0_m_k1_grid_desc),
+                                           decltype(a_g_k0_m_k1_block_desc),
                                            ABlockTransferSrcAccessOrder,
-                                           Sequence<1, 0, 2>,
+                                           Sequence<0, 2, 1, 3>,
                                            ABlockTransferSrcVectorDim,
-                                           2,
+                                           3,
                                            ABlockTransferSrcScalarPerVector,
                                            ABlockTransferDstScalarPerVector_K1,
                                            1,
                                            1,
                                            AThreadTransferSrcResetCoordinateAfterRun,
-                                           true>(a_k0_m_k1_grid_desc,
-                                                 make_multi_index(0, m_block_data_idx_on_grid, 0),
-                                                 a_k0_m_k1_block_desc,
-                                                 make_multi_index(0, 0, 0));
+                                           true>(a_g_k0_m_k1_grid_desc,
+                                                 make_multi_index(g_idx, 0, m_block_data_idx_on_grid, 0),
+                                                 a_g_k0_m_k1_block_desc,
+                                                 make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                            InMemoryDataOperationEnum_t::Set,
-                                           Sequence<KPerBlock, NPerBlock, K1>,
-                                           BBlockTransferThreadSliceLengths_K0_N_K1,
-                                           BBlockTransferThreadClusterLengths_K0_N_K1,
+                                           Sequence<1, KPerBlock, NPerBlock, K1>,
+                                           BBlockTransferThreadSliceLengths_G_K0_N_K1,
+                                           BBlockTransferThreadClusterLengths_G_K0_N_K1,
                                            BBlockTransferThreadClusterArrangeOrder,
                                            FloatAB,
                                            FloatAB,
-                                           decltype(b_k0_n_k1_grid_desc),
-                                           decltype(b_k0_n_k1_block_desc),
+                                           decltype(b_g_k0_n_k1_grid_desc),
+                                           decltype(b_g_k0_n_k1_block_desc),
                                            BBlockTransferSrcAccessOrder,
-                                           Sequence<1, 0, 2>,
+                                           Sequence<0, 2, 1, 3>,
                                            BBlockTransferSrcVectorDim,
-                                           2,
+                                           3,
                                            BBlockTransferSrcScalarPerVector,
                                            BBlockTransferDstScalarPerVector_K1,
                                            1,
                                            1,
                                            BThreadTransferSrcResetCoordinateAfterRun,
-                                           true>(b_k0_n_k1_grid_desc,
-                                                 make_multi_index(0, n_block_data_idx_on_grid, 0),
-                                                 b_k0_n_k1_block_desc,
-                                                 make_multi_index(0, 0, 0));
+                                           true>(b_g_k0_n_k1_grid_desc,
+                                                 make_multi_index(g_idx, 0, n_block_data_idx_on_grid, 0),
+                                                 b_g_k0_n_k1_block_desc,
+                                                 make_multi_index(0, 0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
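The blockwise copies above now describe a 4-D slice: the slice lengths gain a leading 1 (Sequence<1, KPerBlock, ..., K1>) and the source origin becomes (g_idx, 0, m_block_data_idx_on_grid, 0), or its N counterpart for B. BlockwiseTensorSliceTransfer_v4 spreads that slice over the block's threads and vector lanes; the scalar sketch below (hypothetical sizes, not repository code) only models the addressing, showing that a length-1 group dimension leaves the transferred volume unchanged while the window origin selects the batch.

// Hypothetical, simplified model of the blockwise slice copy: a window of lengths
// {1, KPerBlock, MPerBlock, K1} copied out of a packed [G, K0, M, K1] source.
#include <cstdio>
#include <vector>

int main()
{
    const int G = 2, K0 = 4, M = 8, K1 = 2;
    std::vector<float> src(static_cast<std::size_t>(G) * K0 * M * K1);
    for(std::size_t i = 0; i < src.size(); ++i)
        src[i] = static_cast<float>(i);

    const int KPerBlock = 2, MPerBlock = 4;
    const int g_idx = 1, m_origin = 4;

    // Copy the {1, KPerBlock, MPerBlock, K1} window whose origin is (g_idx, 0, m_origin, 0).
    std::vector<float> lds(static_cast<std::size_t>(KPerBlock) * MPerBlock * K1);
    for(int k0 = 0; k0 < KPerBlock; ++k0)
        for(int m = 0; m < MPerBlock; ++m)
            for(int k1 = 0; k1 < K1; ++k1)
            {
                const int src_off = ((g_idx * K0 + k0) * M + (m_origin + m)) * K1 + k1;
                lds[(k0 * MPerBlock + m) * K1 + k1] = src[src_off];
            }

    std::printf("copied %zu elements from group %d\n", lds.size(), g_idx);
    return 0;
}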
@@ -376,8 +386,7 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
                                                                K1>{};

        constexpr auto c_mr_nr_blk_desc =
-           make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
-                                                          Number<NRepeat>{}));
+           make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
            blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
@@ -391,38 +400,39 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size = math::integer_least_multiple(
            a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = p_shared_block;
        FloatAB* p_b_block = p_shared_block + a_block_space_size;

-       constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
-       constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+       constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
+       constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);

        // hack to control index calculation when iterating over A and B matrix for threadwise
        // copy
-       constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-       constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
+       constexpr auto a_g_k0_m_k1_grid_step_hacks = AGridStepHacks{};
+       constexpr auto b_g_k0_n_k1_grid_step_hacks = BGridStepHacks{};

        // hack to control index calculation when move slice window for A and B matrix for
        // threadwise copy
-       constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack =
-           AGridMoveSliceWindowStepHacks{};
-       constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack =
-           BGridMoveSliceWindowStepHacks{};
+       constexpr auto a_g_k0_m_k1_grid_move_slice_window_step_hack =
+           AGridMoveSliceWindowStepHacks{};
+       constexpr auto b_g_k0_n_k1_grid_move_slice_window_step_hack =
+           BGridMoveSliceWindowStepHacks{};

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-           p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
+           p_a_block, a_g_k0_m_k1_block_desc.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-           p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
+           p_b_block, b_g_k0_n_k1_block_desc.GetElementSpaceSize());

        // preload data into LDS
        {
-           a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-           b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+           a_blockwise_copy.RunRead(a_g_k0_m_k1_grid_desc, a_grid_buf, a_g_k0_m_k1_grid_step_hacks);
+           b_blockwise_copy.RunRead(b_g_k0_n_k1_grid_desc, b_grid_buf, b_g_k0_n_k1_grid_step_hacks);

-           a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
-           b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
+           a_blockwise_copy.RunWrite(a_g_k0_m_k1_block_desc, a_block_buf);
+           b_blockwise_copy.RunWrite(b_g_k0_n_k1_block_desc, b_block_buf);
        }

        // main body
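With the group dimension in front, the slice-window step for A and B becomes make_multi_index(0, KPerBlock, 0, 0): each main-loop iteration advances the read window along K0 only, while the g, M/N, and K1 coordinates stay fixed. A small standalone sketch of that 4-D origin arithmetic, using hypothetical values:

// Hypothetical illustration of the 4-D slice-window step (g, k0, m-or-n, k1).
#include <array>
#include <cstdio>

using MultiIndex4 = std::array<int, 4>;

MultiIndex4 move_window(MultiIndex4 origin, const MultiIndex4& step)
{
    for(int d = 0; d < 4; ++d)
        origin[d] += step[d];
    return origin;
}

int main()
{
    const int KPerBlock = 8;
    MultiIndex4 origin     = {/*g=*/1, /*k0=*/0, /*m=*/128, /*k1=*/0};
    const MultiIndex4 step = {0, KPerBlock, 0, 0};

    for(int iter = 0; iter < 3; ++iter)
    {
        origin = move_window(origin, step);
        std::printf("iter %d: window origin = (%d, %d, %d, %d)\n",
                    iter, origin[0], origin[1], origin[2], origin[3]);
    }
    return 0;
}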
@@ -430,27 +440,27 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
            do
            {
-               a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
+               a_blockwise_copy.MoveSrcSliceWindow(a_g_k0_m_k1_grid_desc,
                                                    a_block_slice_copy_step,
-                                                   a_k0_m_k1_grid_move_slice_window_step_hack);
+                                                   a_g_k0_m_k1_grid_move_slice_window_step_hack);
-               b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
+               b_blockwise_copy.MoveSrcSliceWindow(b_g_k0_n_k1_grid_desc,
                                                    b_block_slice_copy_step,
-                                                   b_k0_n_k1_grid_move_slice_window_step_hack);
+                                                   b_g_k0_n_k1_grid_move_slice_window_step_hack);

-               a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+               a_blockwise_copy.RunRead(a_g_k0_m_k1_grid_desc, a_grid_buf, a_g_k0_m_k1_grid_step_hacks);

                block_sync_lds();

-               b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+               b_blockwise_copy.RunRead(b_g_k0_n_k1_grid_desc, b_grid_buf, b_g_k0_n_k1_grid_step_hacks);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

-               a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
-               b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
+               a_blockwise_copy.RunWrite(a_g_k0_m_k1_block_desc, a_block_buf);
+               b_blockwise_copy.RunWrite(b_g_k0_n_k1_block_desc, b_block_buf);

                k_block_data_begin += KPerBlock;
            } while(k_block_data_begin < (K0 - KPerBlock));
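The loop body keeps the usual software pipeline: while the blockwise GEMM consumes the tile currently resident in LDS, the next K0 tile is read from global memory into registers, and that tile is only committed to LDS behind a barrier once the GEMM is done with the old one. The host-side sketch below uses printf stand-ins and a simplified barrier placement; it is an assumption-level trace of that ordering, not the kernel's exact schedule.

// Simplified trace of the main-loop schedule; block_sync_lds() is only a stand-in
// for the device-side LDS barrier.
#include <cstdio>

void block_sync_lds() {}

int main()
{
    const int K0 = 32, KPerBlock = 8;
    int k_block_data_begin = 0;

    // preload: tile 0 goes global -> registers -> LDS before the loop
    std::printf("preload: read+write tile k0=[0,%d)\n", KPerBlock);

    do
    {
        const int k_next = k_block_data_begin + KPerBlock;

        // move the A/B source windows forward and read the next tile into registers
        std::printf("read   next tile k0=[%d,%d) from global\n", k_next, k_next + KPerBlock);

        // compute on the tile that is already in LDS
        std::printf("gemm   on tile k0=[%d,%d) in LDS\n",
                    k_block_data_begin, k_block_data_begin + KPerBlock);

        block_sync_lds(); // GEMM must finish with LDS before it is overwritten

        // commit the tile that was just read
        std::printf("write  next tile k0=[%d,%d) to LDS\n", k_next, k_next + KPerBlock);

        k_block_data_begin += KPerBlock;
    } while(k_block_data_begin < (K0 - KPerBlock));

    // tail: one more GEMM on the last tile resident in LDS
    std::printf("gemm   on tile k0=[%d,%d) in LDS (tail)\n",
                k_block_data_begin, k_block_data_begin + KPerBlock);
    return 0;
}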
@@ -462,19 +472,21 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }

+       /*
        // output: register to global memory
        {
            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();

            constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
            constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
            constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);

            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
@@ -482,16 +494,16 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};

            auto c_thread_copy =
                ThreadwiseTensorSliceTransfer_v1r3<FloatC,
                                                   FloatC,
                                                   decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                                                   decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
                                                   Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
                                                   CThreadTransferSrcDstAccessOrder,
                                                   CThreadTransferSrcDstVectorDim,
                                                   CThreadTransferDstScalarPerVector,
                                                   CGlobalMemoryDataOperation,
                                                   1,
@@ -502,81 +514,81 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
                                   0,
                                   0,
                                   m_thread_data_on_grid / (M3 * M4),
                                   m_thread_data_on_grid % (M3 * M4) / M4,
                                   m_thread_data_on_grid % M4,
                                   n_thread_data_on_grid)};

            auto init_copy = [&](auto c_thread_idx_) {
                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);

                return c_thread_idx_;
            };

            auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
                constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);

                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 mrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
            };

            auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
                constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);

                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 nrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
            };

            auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
                constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);

                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 mrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
            };

            auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
                constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);

                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 nrepeat_step_minus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
            };

            static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
                              (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
                              (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
                              (MRepeat == 1 && NRepeat == 1),
                          "wrong");

            if constexpr(MRepeat == 4 && NRepeat == 4)
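The init_copy / mrepeat_plus_copy / nrepeat_plus_copy / mrepeat_minus_copy / nrepeat_minus_copy helpers write the per-thread C tile out one (MRepeat, NRepeat) sub-block at a time, moving the destination window by ±1 along a repeat dimension between copies; the static_assert limits which repeat shapes are handled and the if constexpr chain (truncated in this excerpt) picks the concrete sequence. The sketch below assumes a snake-style visiting order purely for illustration, since the actual order chosen by those branches is not visible here.

// Hypothetical snake-order traversal over the MRepeat x NRepeat register blocks,
// using only +/-1 window moves as the plus/minus copy helpers above do. The real
// order is selected by the kernel's if constexpr branches; treat this as an assumption.
#include <cstdio>

int main()
{
    const int MRepeat = 4, NRepeat = 4;
    int m = 0, n = 0;

    std::printf("copy block (m=%d, n=%d)\n", m, n); // init_copy
    for(int step = 1; step < MRepeat * NRepeat; ++step)
    {
        const bool forward = (m % 2 == 0); // even m rows walk n upward, odd rows downward
        if(forward ? (n + 1 < NRepeat) : (n > 0))
            n += forward ? 1 : -1;         // nrepeat_plus_copy / nrepeat_minus_copy
        else
            m += 1;                        // mrepeat_plus_copy
        std::printf("copy block (m=%d, n=%d)\n", m, n);
    }
    return 0;
}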