gaoqiong / composable_kernel_ROCM / Commits

Commit 036c5234, authored May 14, 2024 by Adam Osewski

    Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

Parents: 22995e9a, 7843a8a7
Changes: 207
Showing 20 changed files with 3093 additions and 189 deletions (+3093, -189).
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                          +9    -4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp     +2    -12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp       +122  -4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp             +1    -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp               +76   -61
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp     +2493 -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp                 +70   -68
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp     +57   -16
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                                   +1    -1
include/ck/utility/amd_buffer_addressing.hpp                                         +1    -1
include/ck/utility/amd_xdlops.hpp                                                    +1    -1
include/ck/utility/debug.hpp                                                         +8    -1
include/ck/utility/env.hpp                                                           +185  -0
include/ck/utility/flush_icache.hpp                                                  +30   -0
include/ck/utility/loop_scheduler.hpp                                                +2    -1
include/ck/utility/sequence.hpp                                                      +14   -1
include/ck/utility/type.hpp                                                          +4    -17
include/ck/utility/type_convert.hpp                                                  +1    -1
include/ck_tile/core/config.hpp                                                      +5    -0
include/ck_tile/core/numeric/math.hpp                                                +11   -0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
@@ -151,7 +151,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     {
     }

-    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
     {
         const auto M0 = math::integer_divide_ceil(M, MPerBlock);
         const auto N0 = math::integer_divide_ceil(N, NPerBlock);
@@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
 };

 // Grouped Rows of column-vectors WGP mapping
-// Optimized for MI300-like multipe-die chip
+// Optimized for gfx94x-like multipe-die chip
 template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
 struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
@@ -275,7 +275,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
     {
     }

-    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
     {
         const auto M0 = math::integer_divide_ceil(M, MPerBlock);
         const auto N0 = math::integer_divide_ceil(N, NPerBlock);
@@ -428,7 +428,7 @@ struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
     {
     }

-    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
     {
         const auto M0 = math::integer_divide_ceil(M, MPerBlock);
         const auto N0 = math::integer_divide_ceil(N, NPerBlock);
@@ -900,6 +900,11 @@ struct OffsettedBlockToCTileMap
         return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
     }

+    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
+    {
+        return block_to_ctile_map_.CalculateGridSize(M, N);
+    }
+
     UnderlyingBlockToCTileMap block_to_ctile_map_;
     index_t block_start_;
 };
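The recurring change in this file marks CalculateGridSize as __host__ __device__ so the same tile-count arithmetic can run both on the CPU launch path and inside kernels. A minimal self-contained sketch of the idea, assuming illustrative tile sizes and a local ceil-divide helper (not the library's own):

    #include <hip/hip_runtime.h>
    #include <cstdio>

    // Illustrative tile sizes; the real maps take these as template parameters.
    constexpr int MPerBlock = 128;
    constexpr int NPerBlock = 128;

    // Callable from host (to size the launch grid) and from device
    // (e.g. so a workgroup can map blockIdx back to a C tile).
    __host__ __device__ constexpr int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

    __host__ __device__ constexpr int CalculateGridSize(int M, int N)
    {
        return integer_divide_ceil(M, MPerBlock) * integer_divide_ceil(N, NPerBlock);
    }

    int main()
    {
        // Host side: a 1000x2000 problem -> ceil(1000/128) * ceil(2000/128) = 8 * 16 = 128 blocks.
        std::printf("grid size: %d\n", CalculateGridSize(1000, 2000));
        return 0;
    }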
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
@@ -594,11 +594,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
                            Number<NumATensor>{});

-#if 0
-        static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
-                      "Src and Dst ScalarPerVector must be the same");
-#endif
-
         auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<ThisThreadBlock,
                                                                     AsDataType,
@@ -616,7 +611,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             2,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_AK1,
             uniform_sequence_gen_t<NumATensor, false>,
             uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
             Sequence<true>>{as_grid_desc_ak0_m_ak1,
                             idx_as_block_begin,
                             tie(a_block_desc_ak0_m_ak1),
@@ -627,11 +622,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
                            Number<NumBTensor>{});

-#if 0
-        static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
-                      "Src and Dst ScalarPerVector must be the same");
-#endif
-
         auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<ThisThreadBlock,
                                                                     BsDataType,
@@ -649,7 +639,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             2,
             BBlockTransferSrcScalarPerVector,
             BBlockTransferDstScalarPerVector_BK1,
             uniform_sequence_gen_t<NumBTensor, false>,
             uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
             Sequence<true>>{bs_grid_desc_bk0_n_bk1,
                             idx_bs_block_begin,
                             tie(b_block_desc_bk0_n_bk1),
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
@@ -257,7 +257,70 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                                                                        e_grid_desc_m_n);
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename ALayout, typename BLayout, typename ELayout>
+    __host__ __device__ static bool
+    CheckTensorTransfersValidity(index_t MRaw, index_t NRaw, index_t KRaw)
+    {
+        // Check if the vector dim is K1 or M|N
+        const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? KRaw : MRaw;
+        const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? KRaw : NRaw;
+        const auto E_vector_dim_size = NRaw;
+
+        // check vector load for A tensor
+        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+        {
+            if(!(A_vector_dim_size == KRaw &&
+                 A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
+                return false;
+        }
+        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+        {
+            if(!(A_vector_dim_size == MRaw &&
+                 A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
+                return false;
+        }
+        else
+        {
+            return false;
+        }
+
+        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+        {
+            if(!(B_vector_dim_size == NRaw &&
+                 B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
+                return false;
+        }
+        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+        {
+            if(!(B_vector_dim_size == KRaw &&
+                 B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
+                return false;
+        }
+        else
+        {
+            return false;
+        }
+
+        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELayout>)
+        {
+            if(!(E_vector_dim_size == NRaw &&
+                 E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
+                return false;
+        }
+        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELayout>)
+        {
+            if(!(E_vector_dim_size == NRaw &&
+                 CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
+                return false;
+        }
+        else
+        {
+            return false;
+        }
+
+        return true;
+    }
+
     template <typename AGridDesc_M_K,
               typename BGridDesc_N_K,
               typename DsGridDesc_M_N,
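The new CheckTensorTransfersValidity encodes one rule: vectorized global transfers are only legal along a tensor's contiguous dimension, and that dimension's raw extent must divide evenly by the vector width. A minimal host-side sketch of the same rule, with illustrative enum and parameter names (not the library's types):

    #include <cassert>

    enum class Layout { RowMajor, ColumnMajor };

    // For an M x K matrix A: row-major is contiguous in K, column-major in M.
    // A vector transfer of `scalars_per_vector` elements is valid only if the
    // contiguous extent is a whole multiple of the vector width.
    constexpr bool a_transfer_valid(Layout layout, int MRaw, int KRaw, int scalars_per_vector)
    {
        const int contiguous_extent = (layout == Layout::RowMajor) ? KRaw : MRaw;
        return contiguous_extent % scalars_per_vector == 0;
    }

    int main()
    {
        static_assert(a_transfer_valid(Layout::RowMajor, 128, 100, 4));    // K=100 % 4 == 0
        static_assert(!a_transfer_valid(Layout::RowMajor, 128, 100, 8));   // K=100 % 8 != 0
        static_assert(a_transfer_valid(Layout::ColumnMajor, 128, 100, 8)); // M=128 % 8 == 0
        return 0;
    }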
@@ -267,7 +330,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                   const BGridDesc_N_K& b_grid_desc_n_k,
                   const DsGridDesc_M_N& ds_grid_desc_m_n,
                   const EGridDesc_M_N& e_grid_desc_m_n,
-                  const Block2ETileMap&)
+                  [[maybe_unused]] const Block2ETileMap&)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -285,7 +348,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         {
             return false;
         }

         bool valid = true;

         static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -306,7 +368,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         // check gridwise gemm pipeline
         const auto num_k_loop = AK / KPerBlock;
         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
@@ -938,6 +999,63 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                                           e_grid_desc_mblock_mperblock_nblock_nperblock,
                                           block_2_etile_map);
     }

+    template <bool HasMainKBlockLoop,
+              typename AGridDesc_MK,
+              typename BGridDesc_NK,
+              typename DsGridDesc_MN,
+              typename EGridDesc_MN,
+              typename Block2ETileMap>
+    __device__ static void Run(const void* __restrict__ p_a_grid_,
+                               const void* __restrict__ p_b_grid_,
+                               DsGridPointer p_ds_grid,
+                               void* __restrict__ p_e_grid_,
+                               void* __restrict__ p_shared,
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CDEElementwiseOperation& cde_element_op,
+                               const AGridDesc_MK& a_grid_desc_m_k,
+                               const BGridDesc_NK& b_grid_desc_n_k,
+                               const DsGridDesc_MN& ds_grid_desc_m_n,
+                               const EGridDesc_MN& e_grid_desc_m_n,
+                               const Block2ETileMap& block_2_etile_map)
+    {
+        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
+        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
+        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
+
+        // tensor descriptors for block/thread-wise copy
+        const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+        const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_MN{}))>;
+
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+        });
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+
+        Run<HasMainKBlockLoop>(p_a_grid,
+                               p_b_grid,
+                               p_ds_grid,
+                               p_e_grid,
+                               p_shared,
+                               a_element_op,
+                               b_element_op,
+                               cde_element_op,
+                               a_grid_desc_ak0_m_ak1,
+                               b_grid_desc_bk0_n_bk1,
+                               ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                               e_grid_desc_mblock_mperblock_nblock_nperblock,
+                               block_2_etile_map);
+    }
 };

 } // namespace ck
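The Run overload added above accepts type-erased const void*/void* pointers and recovers the typed pointers with reinterpret_cast before delegating to the typed Run, so one entry signature can serve any instantiation's element types. A minimal sketch of the pattern, with illustrative names (ADataType, run_typed are stand-ins, not the library's API):

    #include <cstdio>

    using ADataType = float; // illustrative; in CK this comes from the template arguments

    // Typed implementation.
    static void run_typed(const ADataType* p_a) { std::printf("a[0] = %g\n", p_a[0]); }

    // Type-erased wrapper: recover the typed pointer, then forward.
    static void run(const void* p_a_)
    {
        const auto p_a = static_cast<const ADataType*>(p_a_);
        run_typed(p_a);
    }

    int main()
    {
        ADataType a[1] = {3.5f};
        run(a); // prints a[0] = 3.5
        return 0;
    }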
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
@@ -4,6 +4,7 @@
 #pragma once

 #include <iostream>
+#include <ostream>

 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
@@ -935,12 +935,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if(!(karg.M % MPerBlock == 0))
         {
-#if DEBUG_LOG
-            std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
-                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
-                      << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+            }
             return false;
         }
     }
@@ -952,12 +952,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if(!(karg.N % NPerBlock == 0))
         {
-#if DEBUG_LOG
-            std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
-                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
-                      << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+            }
             return false;
         }
     }
@@ -971,12 +971,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
         auto K_t = karg.KBatch * KPerBlock;
         if(!(karg.K % K_t == 0))
        {
-#if DEBUG_LOG
-            std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
-                      << karg.K << " " << __FILE__ << ":" << __LINE__
-                      << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
+                          << karg.K << " " << __FILE__ << ":" << __LINE__
+                          << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }
@@ -995,13 +995,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if(karg.K % ABlockTransferSrcScalarPerVector != 0)
         {
-#if DEBUG_LOG
-            std::cout << "Arg K (" << karg.K
-                      << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
-                      << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg K (" << karg.K
+                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }
@@ -1009,13 +1009,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if(karg.M % ABlockTransferSrcScalarPerVector != 0)
         {
-#if DEBUG_LOG
-            std::cout << "Arg M (" << karg.M
-                      << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
-                      << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg M (" << karg.M
+                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }
@@ -1024,13 +1024,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if(karg.N % BBlockTransferSrcScalarPerVector != 0)
         {
-#if DEBUG_LOG
-            std::cout << "Arg N (" << karg.N
-                      << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
-                      << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg N (" << karg.N
+                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }
@@ -1038,13 +1038,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if(karg.K % BBlockTransferSrcScalarPerVector != 0)
         {
-#if DEBUG_LOG
-            std::cout << "Arg K (" << karg.K
-                      << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
-                      << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg K (" << karg.K
+                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }
@@ -1053,14 +1053,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
         {
-#if DEBUG_LOG
-            std::cout << "Arg N (" << karg.N
-                      << ") value is not a multiple of "
-                         "CShuffleBlockTransferScalarPerVector_NPerBlock ("
-                      << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
-                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg N (" << karg.N
+                          << ") value is not a multiple of "
+                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
+                          << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
+                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }
@@ -1068,14 +1069,28 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
         {
-#if DEBUG_LOG
-            std::cout << "Arg M (" << karg.M
-                      << ") value is not a multiple of "
-                         "CShuffleBlockTransferScalarPerVector_NPerBlock ("
-                      << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
-                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg M (" << karg.M
+                          << ") value is not a multiple of "
+                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
+                          << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
+                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }

+    if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
+    {
+        if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+        {
+            std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
+                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+        }
+        if(karg.KBatch > 1)
+        {
+            return false;
+        }
+    }
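The hunks in this file all swap the compile-time #if DEBUG_LOG guard for a runtime check, ck::EnvIsEnabled(ENV(CK_LOGGING)), so diagnostics can be switched on without rebuilding (the ENV/EnvIsEnabled machinery arrives with the new include/ck/utility/env.hpp in this merge). A minimal sketch of such an environment-gated logger, assuming nothing about CK's actual implementation:

    #include <cstdlib>
    #include <cstring>
    #include <iostream>

    // Hypothetical stand-in for ck::EnvIsEnabled(ENV(CK_LOGGING)): the variable
    // counts as enabled when set to "1" or "ON"; getenv runs once and is cached.
    static bool logging_enabled()
    {
        static const bool enabled = [] {
            const char* v = std::getenv("CK_LOGGING");
            return v != nullptr && (std::strcmp(v, "1") == 0 || std::strcmp(v, "ON") == 0);
        }();
        return enabled;
    }

    int main()
    {
        const int M = 100, MPerBlock = 64;
        if(M % MPerBlock != 0 && logging_enabled())
        {
            std::cout << "Arg M value is not a multiple of MPerBlock! M: " << M << std::endl;
        }
        return 0;
    }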
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
// pipelines in the same kernel function. Blockers:
// 1. Two separate declarations of the __shared__ pointer are the key to making sure data access
//    operates on two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
//    not reuse the same LDS buffer when __shared__ is declared inside the block gemm pipeline.
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_as_grid,
        karg.p_bs_grid,
        karg.p_ds_grid,
        karg.p_c_grid,
        p_shared,
        karg,
        karg.a_element_op,
        karg.b_element_op,
        karg.c_element_op);
#else
    ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
        kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
    // operate on different LDS chunks at the same time, without an ordering dependency.
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_as_grid,
        karg.p_bs_grid,
        karg.p_ds_grid,
        karg.p_c_grid,
        p_shared_0,
        p_shared_1,
        karg,
        karg.a_element_op,
        karg.b_element_op,
        karg.c_element_op);
#else
    ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
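// The two __shared__ allocations above enable a ping-pong scheme: while the MFMA
// stage consumes tile k from one LDS chunk, the global->LDS copy stage writes tile
// k+1 into the other, and the buffers swap roles every iteration. A sketch of the
// control flow (illustrative, not the CK pipeline itself):
//
//   write tile 0 -> lds[0]; barrier;
//   for t in 0 .. num_tiles-1:
//       if t+1 < num_tiles: write tile t+1 -> lds[(t+1) % 2]   // prefetch
//       mfma on lds[t % 2]                                     // consume
//       barrier;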
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AsDataType,
          typename BsDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
          typename ComputeTypeA                       = CDataType,
          typename ComputeTypeB                       = ComputeTypeA>
struct GridwiseGemm_xdl_cshuffle_v3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    using LDSTypeA = ComputeTypeA;
    using LDSTypeB = ComputeTypeB;
    // K1 should be Number<...>
    static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    static constexpr index_t NumATensor = AsDataType::Size();
    static constexpr index_t NumBTensor = BsDataType::Size();
    static constexpr index_t NumDTensor = DsDataType::Size();
    static constexpr auto MakeAsGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
                return static_cast<const ADataType_*>(nullptr);
            },
            Number<NumATensor>{});
    }

    static constexpr auto MakeBsGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
                return static_cast<const BDataType_*>(nullptr);
            },
            Number<NumBTensor>{});
    }

    static constexpr auto MakeDsGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                return static_cast<const DDataType*>(nullptr);
            },
            Number<NumDTensor>{});
    }

    using AsGridPointer = decltype(MakeAsGridPointer());
    using BsGridPointer = decltype(MakeBsGridPointer());
    using DsGridPointer = decltype(MakeDsGridPointer());

    static constexpr index_t KPack =
        math::max(math::lcm(AK1Number, BK1Number),
                  MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
    }

    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_least_multiple(M, MPerBlock);
    }

    __host__ static auto CalculateNPadded(index_t N)
    {
        return math::integer_least_multiple(N, NPerBlock);
    }

    __host__ static auto CalculateKPadded(index_t K)
    {
        return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
    }

    __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }

    __host__ static auto CalculateMBlock(index_t M)
    {
        return math::integer_divide_ceil(M, MPerBlock);
    }

    __host__ static auto CalculateNBlock(index_t N)
    {
        return math::integer_divide_ceil(N, NPerBlock);
    }
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
        constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});

        return transform_tensor_descriptor(
            TileDesc_K0_MN_K1{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }
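    // The transform above merges (K0, K1) into a single K axis (k = k0 * K1 + k1) and
    // unmerges MN into (xdl_per_wave, wave, lane) by division/modulo. For example, with
    // illustrative extents K0 = 4, K1 = 8, MNXdlPerWave = 2, MNWaves = 2, MNPerXdl = 16:
    //   (k0, k1) = (3, 5)  ->  k = 3 * 8 + 5 = 29  (and 29 < K0 * K1 = 32)
    //   mn = 53            ->  lane = 53 % 16 = 5, wave = (53 / 16) % 2 = 1,
    //                          xdl  = 53 / (16 * 2) = 1
    //                          (check: 1 * 32 + 1 * 16 + 5 = 53)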
    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_right_pad_transform(M, MPad - M),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_right_pad_transform(M, MPad - M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }
    __host__ __device__ static auto
    MakeAsGridDescriptor_AK0_M_AK1(const index_t M,
                                   const index_t MPad,
                                   const index_t K,
                                   const index_t KPad,
                                   const std::array<index_t, NumATensor>& StrideAs,
                                   const index_t AK0)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0);
            },
            Number<NumATensor>{});
    }
    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_right_pad_transform(N, NPad - N),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }
    __host__ __device__ static auto
    MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
                                   const index_t KPad,
                                   const index_t N,
                                   const index_t NPad,
                                   const std::array<index_t, NumBTensor>& StrideBs,
                                   const index_t BK0)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0);
            },
            Number<NumBTensor>{});
    }
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(M, MPad - M),
                                                          make_right_pad_transform(N, NPad - N)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }
    __host__ __device__ static auto MakeDsGridDescriptor_M_N(
        index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
    {
        return generate_tuple(
            [&](auto i) { return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); },
            Number<NumDTensor>{});
    }
    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         std::array<index_t, NumATensor> StrideAs_,
                         std::array<index_t, NumBTensor> StrideBs_,
                         std::array<index_t, NumDTensor> StrideDs_,
                         index_t StrideC_,
                         index_t KBatch_)
            : M{M_},
              N{N_},
              K{K_},
              StrideAs{StrideAs_},
              StrideBs{StrideBs_},
              StrideDs{StrideDs_},
              StrideC{StrideC_},
              KBatch{KBatch_},
              MPadded{CalculateMPadded(M_)},
              NPadded{CalculateNPadded(N_)},
              KRead{CalculateKRead(K_, KBatch_)},
              KPadded{CalculateKPadded(K_, KBatch_)},
              AK0{CalculateAK0Padded(K_, KBatch_)},
              BK0{CalculateBK0Padded(K_, KBatch_)},
              MBlock{CalculateMBlock(M_)},
              NBlock{CalculateNBlock(N_)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K
                      << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
                      << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
                      << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        std::array<index_t, NumATensor> StrideAs;
        std::array<index_t, NumBTensor> StrideBs;
        std::array<index_t, NumDTensor> StrideDs;
        index_t StrideC;
        index_t KBatch;
        index_t MPadded;
        index_t NPadded;
        index_t KRead;
        index_t KPadded;
        index_t AK0;
        index_t BK0;
        index_t MBlock;
        index_t NBlock;
    };
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
                          std::array<const void*, NumBTensor> p_bs_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,
                          void* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          std::array<index_t, NumATensor> StrideAs_,
                          std::array<index_t, NumBTensor> StrideBs_,
                          std::array<index_t, NumDTensor> StrideDs_,
                          index_t StrideC_,
                          index_t k_batch_,
                          AElementwiseOperation a_element_op_,
                          BElementwiseOperation b_element_op_,
                          CElementwiseOperation c_element_op_)
            : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideC_, k_batch_},
              p_as_grid{},
              p_bs_grid{},
              p_ds_grid{},
              p_c_grid{static_cast<CDataType*>(p_c_grid_)},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              c_element_op{c_element_op_}
        {
            // populate pointer, desc for As
            static_for<0, NumATensor, 1>{}([&](auto i) {
                using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;

                // A pointer
                p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
            });

            // populate pointer, desc for Bs
            static_for<0, NumBTensor, 1>{}([&](auto i) {
                using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;

                // B pointer
                p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
            });

            // populate pointer, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
            });
        }

        AsGridPointer p_as_grid;
        BsGridPointer p_bs_grid;
        DsGridPointer p_ds_grid;
        CDataType* p_c_grid;

        const AElementwiseOperation a_element_op;
        const BElementwiseOperation b_element_op;
        const CElementwiseOperation c_element_op;
    };
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = blockIdx.z * karg.KRead;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
            }

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                b_k_split_offset = blockIdx.z * karg.KRead;
            }

            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
            {
                karg.K = karg.KRead;
            }
            else
            {
                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
    };
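    // Example of the split-K bookkeeping above, with illustrative values K = 1000,
    // KBatch = 4, KRead = 256, and row-major A (contiguous in K):
    //   batch z starts at a_k_split_offset = z * 256 within each row of A;
    //   batches 0..2 each process KRead = 256 K-elements, and the last batch takes
    //   the remainder, K - KRead * (KBatch - 1) = 1000 - 768 = 232.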
#if 0
struct SplitKBatchOffsetMultiABD
{
__device__ SplitKBatchOffsetMultiABD(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
Argument& karg)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout_>)
{
as_k_split_offset[i] = blockIdx.z * karg.KRead;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout_>)
{
as_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideAs[i];
}
p_as_grid_(i) = p_as_grid[i] + as_k_split_offset[i];
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout_>)
{
bs_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideBs[i];
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout_>)
{
bs_k_split_offset[i] = blockIdx.z * karg.KRead;
}
p_bs_grid_(i) = p_bs_grid[i] + bs_k_split_offset[i];
});
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
}
AsGridPointer p_as_grid_;
BsGridPointer p_bs_grid_;
std::array<index_t, NumATensor> as_k_split_offset;
std::array<index_t, NumBTensor> bs_k_split_offset;
};
#endif
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        if constexpr(ABlockLdsExtraM)
        {
            return make_naive_tensor_descriptor(
                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
                make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
        }
        // The xor tensor transformation requests more (unnecessary) VGPR usage and can
        // cause register spills in some cases.
        else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
                                           ? 1
                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(
                    AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
                make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
                                                         Number<AK0Number * MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
                           make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_ak0_mldslayer_m_ak1,
                make_tuple(make_pass_through_transform(AK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
        }
        else // ColumnMajor A
        {
            // The kfold and mpair dimensions are not always required;
            // more dimensions in merge_transform increase the difficulty of generating
            // imm-arg offsets for the compiler.
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1<=mpair<=n0
            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));

            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * M1>{},
                           Number<kfold * M0 / mpair>{},
                           Number<mpair>{},
                           AK1Number));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_transform(
                        make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(AK1Number)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<1>{},
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
        }
    }
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        if constexpr(BBlockLdsExtraN)
        {
            return make_naive_tensor_descriptor(
                make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
                make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            // NLdsLayer * K0 as logical Bank
            constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
                                           ? 1
                                           : 32 * 4 / KPerBlock / sizeof(LDSTypeB);
            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(
                    BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
                make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
                                                         Number<BK0Number * NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
                           make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_bk0_nldslayer_n_bk1,
                make_tuple(make_pass_through_transform(BK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
        }
        else // RowMajor B
        {
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(LDSTypeB));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1<=npair<=n0
            constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB)));

            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * N1>{},
                           Number<kfold * N0 / npair>{},
                           Number<npair>{},
                           BK1Number));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_transform(
                        make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
                    make_pass_through_transform(Number<npair>{}),
                    make_pass_through_transform(BK1Number)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
                    make_pass_through_transform(Number<npair>{}),
                    make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<1>{},
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
        }
    }
    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
    using BlockwiseGemmPipe = remove_cvref_t<decltype(
        BlockGemmPipeline_Selector<BlkGemmPipelineVer,
                                   BlkGemmPipeSched,
                                   BlockSize,
                                   LDSTypeA,
                                   LDSTypeB,
                                   ComputeTypeA,
                                   AccDataType,
                                   decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
                                   decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
                                   decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
                                       GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
                                   decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
                                       GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
                                   ABlockTransferSrcScalarPerVector,
                                   BBlockTransferSrcScalarPerVector,
                                   MPerBlock,
                                   NPerBlock,
                                   KPerBlock,
                                   MPerXdl,
                                   NPerXdl,
                                   MXdlPerWave,
                                   NXdlPerWave,
                                   KPack>())>;
    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
                          b_block_space_size_aligned * sizeof(LDSTypeB)),
                         c_block_size * sizeof(CShuffleDataType));
    }
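    // Example of the LDS budget above, under assumed (illustrative) parameters:
    // 128x128x32 tiles with 2-byte A/B elements and a 32x128 fp32 C-shuffle slice:
    //   A tile: 128 * 32 * 2 = 8192 B,  B tile: 128 * 32 * 2 = 8192 B  -> sum 16384 B
    //   C-shuffle slice: 32 * 128 * 4 = 16384 B
    //   GetSharedMemoryNumberOfByte = max(16384, 16384) = 16384 B; the mainloop and
    //   the epilogue reuse the same allocation, which is why max (not sum) applies.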
    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    __host__ static constexpr bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(karg.M % MPerBlock == 0))
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M
                              << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(karg.N % NPerBlock == 0))
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N
                              << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
            auto K_t                = karg.KBatch * KReadVec;
            auto KReadPadSplited    = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);

        if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
        {
            if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
            {
                return false;
            }
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }
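In the padded split-K branch above, a KBatch value is rejected when the padded per-split read range already covers all of K before the last split, i.e. the last split would receive no real work. A standalone sketch of that feasibility test (hypothetical host code; std::lcm stands in for math::lcm):

    #include <numeric> // std::lcm

    // Illustration only: mirrors the padded split-K check above.
    // ak1/bk1 stand in for AK1Number/BK1Number.
    bool split_k_feasible(int K, int KBatch, int ak1, int bk1)
    {
        const int k_read_vec         = std::lcm(ak1, bk1);
        const int k_t                = KBatch * k_read_vec;
        const int k_per_split_padded = ((K + k_t - 1) / k_t) * k_read_vec;
        // the first (KBatch - 1) splits must not already cover all of K
        return k_per_split_padded * (KBatch - 1) < K;
    }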
    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
    template <typename CGridDesc>
    __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
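Each unmerge transform above splits one coordinate into a (block, in-block) pair, turning the M x N view into an MBlock x MPerBlock x NBlock x NPerBlock view. A scalar sketch of the same index mapping (illustration only; names are hypothetical):

    // (m, n) -> (m_block, m_in_block, n_block, n_in_block), as realized by
    // the two unmerge transforms above.
    struct TileIdx
    {
        int mb, mp, nb, np;
    };

    TileIdx to_mblock_mperblock_nblock_nperblock(int m, int n, int m_per_block, int n_per_block)
    {
        return {m / m_per_block, m % m_per_block, n / n_per_block, n % n_per_block};
    }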
    template <typename DsGridDesc>
    __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                    ds_grid_desc_m_n[i], MBlock, NBlock);
            },
            Number<NumDTensor>{});
    }

    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;

    // return block_id to C matrix tile idx (m0, n0) mapping
    // if arch = gfx942
    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
    // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Odd>
    __device__ static void Run(AsGridPointer& p_as_grid,
                               BsGridPointer& p_bs_grid,
                               DsGridPointer& p_ds_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op)
    {
// std::array<index_t, NumATensor> StrideAs = {problem.StrideA};
// std::array<index_t, NumBTensor> StrideBs = {problem.StrideB};
// AsGridPointer p_as_grid;
// BsGridPointer p_bs_grid;
// DsGridPointer p_ds_grid;
// const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
// problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
// const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
// problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
        const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
#if 0
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_m_n(j) = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs[j]);
});
#endif
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
// const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
// const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_bs_grid[I0], b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        const auto as_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
            },
            Number<NumATensor>{});

        const auto bs_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
            },
            Number<NumBTensor>{});

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
#if 0
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
        const auto idx_as_block_begin =
            generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
                           Number<NumATensor>{});

        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
            ThisThreadBlock,
            AsDataType,
            Tuple<LDSTypeA>,
            decltype(as_grid_desc_ak0_m_ak1),
            decltype(tie(a_block_desc_ak0_m_ak1)),
            AElementwiseOperation,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
            Sequence<true>,
            BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
                                                idx_as_block_begin,
                                                tie(a_block_desc_ak0_m_ak1),
                                                make_tuple(make_multi_index(0, 0, 0)),
                                                a_element_op};
#endif
#if 0
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
        const auto idx_bs_block_begin =
            generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
                           Number<NumBTensor>{});

        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
            ThisThreadBlock,
            BsDataType,
            Tuple<LDSTypeB>,
            decltype(bs_grid_desc_bk0_n_bk1),
            decltype(tie(b_block_desc_bk0_n_bk1)),
            BElementwiseOperation,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            Sequence<BK0Number, NPerBlock, BK1Number>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
            Sequence<true>,
            BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
                                                idx_bs_block_begin,
                                                tie(b_block_desc_bk0_n_bk1),
                                                make_tuple(make_multi_index(0, 0, 0)),
                                                b_element_op};
#endif
// LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        // Cast after lds
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeB*>(p_shared) +
                a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(as_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         as_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         bs_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         bs_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};
#if 0
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
#else
            using EDataType = CDataType;

            // tuple of reference to C/Ds tensor descriptors
            const auto c_ds_desc_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                generate_tie(
                    [&](auto i) -> const auto& // return type should be reference
                    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                    Number<NumDTensor>{}));

            // tuple of reference to C/Ds tensor buffers
            const auto c_ds_buf_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_buf),
                generate_tie(
                    [&](auto i) -> const auto& // return type should be reference
                    { return ds_grid_buf[i]; },
                    Number<NumDTensor>{}));

            // tuple of starting index of C/Ds blockwise copy
            const auto idx_c_ds_block_begin = container_concat(
                make_tuple(make_multi_index(0, 0, 0, 0)),
                generate_tuple(
                    [&](auto) {
                        return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
                    },
                    Number<NumDTensor>{}));

            const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
                c_grid_desc_mblock_mperblock_nblock_nperblock;

            using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
            const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
            const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock =
                CShuffleBlockTransferScalarPerVector_NPerBlock;

            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
                ThisThreadBlock,
                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
                Tuple<EDataType>,
                decltype(c_ds_desc_refs),
                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                CElementwiseOperation,
                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
                                                                            // support arbitray type
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
                Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
                3,                    // index_t SrcVectorDim,
                3,                    // index_t DstVectorDim,
                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
                sequence_merge_t<
                    Sequence<true>,
                    uniform_sequence_gen_t<NumDTensor,
                                           false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
                Sequence<false>>                    // ThreadTransferDstResetCoordinateAfterRunFlags
                {c_ds_desc_refs,
                 idx_c_ds_block_begin,
                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
                 c_element_op};
#endif
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           M2,
                                           1,
                                           M4,
                                           1>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
#if 0
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
#else
            // space filling curve for shuffled blockwise C/D/E
            constexpr auto sfc_cde_block =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
#endif
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();
#if 0
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
#else
                // each block copy its data from LDS to global
                cde_block_copy_lds_and_global.Run(
                    c_ds_desc_refs,
                    c_ds_buf_refs,
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    tie(c_grid_buf));

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto cde_lds_and_global_step =
                        sfc_cde_block.GetForwardStep(access_id);

                    // move on Ds
                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                    });

                    // move on E
                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                        I0,
                        cde_lds_and_global_step);
                }
#endif
            });
        }
    }
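Run_2Lds below is the double-buffered variant of Run: it takes two LDS allocations and alternates them across K iterations so the global-to-LDS copy of the next tile can overlap the math on the current one. A minimal standalone sketch of that ping-pong pattern (illustration only, hypothetical types; buf[0]/buf[1] play the role of p_shared_0/p_shared_1, and on a GPU the two stages would be separated by barriers):

    // Illustration only: alternate two staging buffers across loop iterations.
    template <typename Buffer, typename LoadFn, typename ComputeFn>
    void pingpong_loop(Buffer buf[2], int num_iter, LoadFn load, ComputeFn compute)
    {
        load(buf[0], 0); // prologue: stage the first tile
        for(int i = 0; i < num_iter; ++i)
        {
            if(i + 1 < num_iter)
                load(buf[(i + 1) % 2], i + 1); // prefetch next tile into the idle buffer
            compute(buf[i % 2]);               // consume the tile staged previously
        }
    }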
#if 1
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Odd>
    __device__ static void Run_2Lds(AsGridPointer& p_as_grid,
                                    BsGridPointer& p_bs_grid,
                                    DsGridPointer& p_ds_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem,
                                    const AElementwiseOperation& a_element_op,
                                    const BElementwiseOperation& b_element_op,
                                    const CElementwiseOperation& c_element_op)
    {
// const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
// problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
// const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
// problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
        const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
// const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
// const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        const auto as_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
            },
            Number<NumATensor>{});

        const auto bs_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
            },
            Number<NumBTensor>{});

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
#if 0
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
        const auto idx_as_block_begin =
            generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
                           Number<NumATensor>{});

        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
            ThisThreadBlock,
            AsDataType,
            Tuple<LDSTypeA>,
            decltype(as_grid_desc_ak0_m_ak1),
            decltype(tie(a_block_desc_ak0_m_ak1)),
            AElementwiseOperation,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            uniform_sequence_gen_t<NumATensor, false>,
            Sequence<true>,
            BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
                                                idx_as_block_begin,
                                                tie(a_block_desc_ak0_m_ak1),
                                                make_tuple(make_multi_index(0, 0, 0)),
                                                a_element_op};
#endif
#if 0
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
        const auto idx_bs_block_begin =
            generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
                           Number<NumBTensor>{});

        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
            ThisThreadBlock,
            BsDataType,
            Tuple<LDSTypeB>,
            decltype(bs_grid_desc_bk0_n_bk1),
            decltype(tie(b_block_desc_bk0_n_bk1)),
            BElementwiseOperation,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            Sequence<BK0Number, NPerBlock, BK1Number>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            uniform_sequence_gen_t<NumBTensor, false>,
            Sequence<true>,
            BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
                                                idx_bs_block_begin,
                                                tie(b_block_desc_bk0_n_bk1),
                                                make_tuple(make_multi_index(0, 0, 0)),
                                                b_element_op};
#endif
// LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeB*>(p_shared_0) +
                a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeB*>(p_shared_1) +
                a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(as_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         as_grid_buf,
                                                                         a_block_bufs,
                                                                         a_block_slice_copy_step,
                                                                         bs_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         bs_grid_buf,
                                                                         b_block_bufs,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared_0),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};
#if 0
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
#else
            using EDataType = CDataType;

            // tuple of reference to C/Ds tensor descriptors
            const auto c_ds_desc_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                generate_tie(
                    [&](auto i) -> const auto& // return type should be reference
                    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                    Number<NumDTensor>{}));

            // tuple of reference to C/Ds tensor buffers
            const auto c_ds_buf_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_buf),
                generate_tie(
                    [&](auto i) -> const auto& // return type should be reference
                    { return ds_grid_buf[i]; },
                    Number<NumDTensor>{}));

            // tuple of starting index of C/Ds blockwise copy
            const auto idx_c_ds_block_begin = container_concat(
                make_tuple(make_multi_index(0, 0, 0, 0)),
                generate_tuple(
                    [&](auto) {
                        return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
                    },
                    Number<NumDTensor>{}));

            const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
                c_grid_desc_mblock_mperblock_nblock_nperblock;

            using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
            const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
            const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock =
                CShuffleBlockTransferScalarPerVector_NPerBlock;

            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
                ThisThreadBlock,
                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
                Tuple<EDataType>,
                decltype(c_ds_desc_refs),
                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                CElementwiseOperation,
                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
                                                                            // support arbitray type
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
                Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
                3,                    // index_t SrcVectorDim,
                3,                    // index_t DstVectorDim,
                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
                sequence_merge_t<
                    Sequence<true>,
                    uniform_sequence_gen_t<NumDTensor,
                                           false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
                Sequence<false>>                    // ThreadTransferDstResetCoordinateAfterRunFlags
                {c_ds_desc_refs,
                 idx_c_ds_block_begin,
                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
                 c_element_op};
#endif
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           M2,
                                           1,
                                           M4,
                                           1>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
#if 1
            // space filling curve for shuffled blockwise C/D/E
            constexpr auto sfc_cde_block =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
#endif
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();
#if 0
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
#else
                // each block copy its data from LDS to global
                cde_block_copy_lds_and_global.Run(
                    c_ds_desc_refs,
                    c_ds_buf_refs,
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    tie(c_grid_buf));

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto cde_lds_and_global_step =
                        sfc_cde_block.GetForwardStep(access_id);

                    // move on Ds
                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                    });

                    // move on E
                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                        I0,
                        cde_lds_and_global_step);
                }
#endif
            });
        }
    }
#endif
};
} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
@@ -446,12 +446,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     {
         if(!(karg.M % MPerBlock == 0))
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
                 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                           << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                           << std::endl;
-#endif // DEBUG_LOG
+            }
             return false;
         }
     }
@@ -463,12 +463,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     {
         if(!(karg.N % NPerBlock == 0))
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
                 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                           << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                           << std::endl;
-#endif // DEBUG_LOG
+            }
             return false;
         }
     }
@@ -482,12 +482,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         auto K_t = karg.k_batch * K0PerBlock * K1;
         if(!(karg.K % K_t == 0))
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
                 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                           << karg.K << " " << __FILE__ << ":" << __LINE__
                           << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            }
             return false;
         }
     }
@@ -496,13 +496,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     {
         if(karg.K % ABlockTransferSrcScalarPerVector != 0)
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
                 std::cout << "Arg K (" << karg.K
                           << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                           << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                           << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            }
             return false;
         }
     }
@@ -510,13 +510,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     {
         if(karg.M % ABlockTransferSrcScalarPerVector != 0)
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
                 std::cout << "Arg M (" << karg.M
                           << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                           << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                           << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            }
             return false;
         }
     }
@@ -525,13 +525,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     {
         if(karg.N % BBlockTransferSrcScalarPerVector != 0)
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
                 std::cout << "Arg N (" << karg.N
                           << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                           << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                           << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            }
             return false;
         }
     }
@@ -539,13 +539,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     {
         if(karg.K % BBlockTransferSrcScalarPerVector != 0)
         {
-#if DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
                 std::cout << "Arg K (" << karg.K
                           << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                           << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                           << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            }
             return false;
         }
     }
@@ -554,14 +554,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     {
         if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
        {
-#if DEBUG_LOG
-            std::cout << "Arg N (" << karg.N
-                      << ") value is not a multiple of "
-                         "CBlockTransferScalarPerVector_NWaveNPerXDL ("
-                      << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg N (" << karg.N
+                          << ") value is not a multiple of "
+                             "CBlockTransferScalarPerVector_NWaveNPerXDL ("
+                          << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
+                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }
@@ -569,14 +569,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     {
         if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
         {
-#if DEBUG_LOG
-            std::cout << "Arg M (" << karg.M
-                      << ") value is not a multiple of "
-                         "CBlockTransferScalarPerVector_NWaveNPerXDL ("
-                      << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "Arg M (" << karg.M
+                          << ") value is not a multiple of "
+                             "CBlockTransferScalarPerVector_NWaveNPerXDL ("
+                          << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
+                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
     }
@@ -584,12 +584,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         const auto num_k_loop = karg.K0Padded / K0PerBlock;
         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
-#if DEBUG_LOG
-            std::cout << "The number of k loops (" << num_k_loop
-                      << ") value is not supported by GridwiseGemm Pipeline."
-                      << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
-                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "The number of k loops (" << num_k_loop
+                          << ") value is not supported by GridwiseGemm Pipeline."
+                          << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock
+                          << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+            }
             return false;
         }
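All of the hunks above follow one pattern: the compile-time #if DEBUG_LOG guard is replaced by the runtime switch ck::EnvIsEnabled(ENV(CK_LOGGING)) introduced with include/ck/utility/env.hpp, so diagnostics can be toggled without rebuilding. A minimal standalone sketch of the same guard pattern (hypothetical helper using plain std::getenv rather than the CK macro):

    #include <cstdlib>
    #include <iostream>

    // Hypothetical stand-in for ck::EnvIsEnabled(ENV(CK_LOGGING)): treat the
    // variable as enabled when set to a non-empty value other than "0".
    bool logging_enabled()
    {
        const char* v = std::getenv("CK_LOGGING");
        return v != nullptr && v[0] != '\0' && !(v[0] == '0' && v[1] == '\0');
    }

    void check_example(int M, int MPerBlock)
    {
        if(M % MPerBlock != 0 && logging_enabled())
        {
            std::cout << "Arg M value is not a multiple of MPerBlock! M: " << M << std::endl;
        }
    }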
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
@@ -42,7 +42,8 @@ template <typename SrcDatas,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
           typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
-          typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
+          typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
+          index_t NumThreadScratch = 1>
 struct ThreadwiseTensorSliceTransfer_v7r2
 {
     static constexpr auto I0 = Number<0>{};
@@ -139,14 +140,19 @@ struct ThreadwiseTensorSliceTransfer_v7r2
     // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
     // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
     template <typename SrcBuffers,
+              index_t ThreadScratchId = 0,
               enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
-    __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
+    __device__ void RunRead(const SrcDescs& src_descs,
+                            const SrcBuffers& src_bufs,
+                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         // loop over space-filling curve
         static_for<0, src_num_access, 1>{}([&](auto iAccess) {
             auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
             auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();

+            bool oob_val = true;
+
             // copy data from src_bufs into src_vectors
             static_for<0, nSrc, 1>{}([&](auto i) {
                 using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
@@ -155,9 +161,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                     coordinate_has_valid_offset_assuming_visible_index_is_valid(
                         src_descs[i], src_coords_[i]);

+                oob_val = oob_val & is_src_valid;
+
                 src_vectors(i).template AsType<src_vector_t>()(I0) =
-                    src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
-                                                           is_src_valid);
+                    src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
             });

             constexpr auto get_elem_op_vec_len = []() {
@@ -218,7 +225,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                 unpack2(element_op_, dst_data_refs, src_data_refs);
             });

-            elm_vectors_tuple_(iAccess) = elm_vectors;
+            elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
+            oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;

             // move coordinate
             if constexpr(iAccess.value != src_num_access - 1)
@@ -245,17 +253,38 @@ struct ThreadwiseTensorSliceTransfer_v7r2
         });
     }

-    __device__ void TransposeFromElmToDst()
+#if 1
+    template <index_t ThreadScratchId = 0>
+    __device__ void
+    OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
+    {
+        // loop over space-filling curve
+        static_for<0, src_num_access, 1>{}([&](auto iAccess) {
+            auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
+            auto oob_val     = oob_vectors_tuple_[thread_scratch_id][iAccess];
+
+            static_for<0, nDst, 1>{}([&](auto i) {
+                using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
+
+                elm_vectors(i).template AsType<elm_vector_t>()(I0) =
+                    oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0]
+                            : elm_vector_t{0};
+            });
+
+            elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
+        });
+    }
+#endif
+
+    template <index_t ThreadScratchId = 0>
+    __device__ void
+    TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;

-        using SrcThreadScratch =
+        using ElmThreadScratch =
             StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                             DstData,
                                             SrcScalarPerVector,
                                             decltype(GetSrcThreadScratchDescriptor()),
                                             true>;
         using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                                  DstData,
@@ -263,15 +292,17 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                                                                  decltype(GetDstThreadScratchDescriptor()),
                                                                  true>;

-        SrcThreadScratch elm_thread_scratch_;
+        ElmThreadScratch elm_thread_scratch_;
         DstThreadScratch dst_thread_scratch_;

         elm_thread_scratch_.data_ =
-            bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_);
+            bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);

         if constexpr(SrcVectorDim != DstVectorDim &&
                      ((is_same<half_t, remove_cvref_t<DstData>>::value &&
                        SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                       (is_same<f8_t, remove_cvref_t<DstData>>::value &&
                        SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
                       (is_same<int8_t, remove_cvref_t<DstData>>::value &&
                        SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
         {
@@ -338,20 +369,24 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                 [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
         }

-        dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_);
+        dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
     }

     // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
     // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
     template <typename DstBuffers,
+              index_t ThreadScratchId = 0,
               enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
-    __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
+    __device__ void RunWrite(const DstDescs& dst_descs,
+                             DstBuffers dst_bufs,
+                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
-        TransposeFromElmToDst();
+        OOBCheck(thread_scratch_id);
+        TransposeFromElmToDst(thread_scratch_id);

         // loop over space-filling curve
         static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
-            auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}];
+            auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];

             // copy data from buf_vectors into dst_bufs
             static_for<0, nDst, 1>{}([&](auto i) {
...
...
@@ -578,8 +613,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
static
constexpr
auto
src_num_access
=
SrcSpaceFillingCurve
::
GetNumOfAccess
();
static
constexpr
auto
dst_num_access
=
DstSpaceFillingCurve
::
GetNumOfAccess
();
StaticallyIndexedArray
<
ElmVectorsType
,
src_num_access
>
elm_vectors_tuple_
;
StaticallyIndexedArray
<
DstVectorsType
,
dst_num_access
>
dst_vectors_tuple_
;
using
ElmVectorTuple
=
StaticallyIndexedArray
<
ElmVectorsType
,
src_num_access
>
;
using
DstVectorTuple
=
StaticallyIndexedArray
<
DstVectorsType
,
dst_num_access
>
;
StaticallyIndexedArray
<
ElmVectorTuple
,
NumThreadScratch
>
elm_vectors_tuple_
;
StaticallyIndexedArray
<
DstVectorTuple
,
NumThreadScratch
>
dst_vectors_tuple_
;
using
OOBVectorTuple
=
StaticallyIndexedArray
<
bool
,
src_num_access
>
;
StaticallyIndexedArray
<
OOBVectorTuple
,
NumThreadScratch
>
oob_vectors_tuple_
;
SrcCoords
src_coords_
;
DstCoords
dst_coords_
;
...
...
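Taken together, these v7r2 changes defer out-of-bounds handling from the read path to the write path: RunRead now issues every load with the validity flag forced to true, accumulates the real per-access validity into oob_vectors_tuple_, and RunWrite calls OOBCheck to zero any vector that came from an out-of-bounds read before it is transposed and stored. The thread_scratch_id index lets up to NumThreadScratch reads be in flight before their matching writes. A minimal standalone sketch of this deferred-predication pattern (DeferredOobBuffer, load_raw, and in_bounds are illustrative names, not part of the CK API):

// Sketch: defer out-of-bounds masking from load time to store time.
#include <array>
#include <cstddef>

template <typename Vec, std::size_t NumAccess>
struct DeferredOobBuffer
{
    std::array<Vec, NumAccess> data{};
    std::array<bool, NumAccess> valid{};

    // read phase: load unconditionally, only remember validity
    template <typename LoadFn, typename ValidFn>
    void read(LoadFn load_raw, ValidFn in_bounds)
    {
        for(std::size_t i = 0; i < NumAccess; ++i)
        {
            data[i]  = load_raw(i);  // always issued, may contain garbage
            valid[i] = in_bounds(i); // cheap flag, no branch in the load
        }
    }

    // write phase: zero out anything that was read out of bounds,
    // mirroring OOBCheck() writing elm_vector_t{0}
    void oob_check()
    {
        for(std::size_t i = 0; i < NumAccess; ++i)
        {
            if(!valid[i])
                data[i] = Vec{};
        }
    }
};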
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp View file @ 036c5234
...
@@ -95,7 +95,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
     // Wave mode dependent property
     static constexpr index_t wave_size = Number<WaveSize>{};
-    // * Fixed in Navi3x, Will be wave mode dependent on Navi4x
+    // * Fixed on gfx11, Will be wave mode dependent for future architectures
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     // * num_acc_vgprs_per_wave along M direction
...
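The register math counts 32-bit VGPRs, hence the division by 4. A worked check for the wmma_f32_16x16x16_f16 specialization named in the hunk header (assuming src_a_data_size == sizeof(half_t) == 2):

// Worked example; the names mirror the members of wmma_type above.
constexpr int m_per_wmma      = 16; // from wmma_f32_16x16x16_f16
constexpr int src_a_data_size = 2;  // sizeof(half_t)

static_assert(m_per_wmma * src_a_data_size / 4 == 8,
              "one wave holds 8 source-A VGPRs for this instruction");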
include/ck/utility/amd_buffer_addressing.hpp View file @ 036c5234
...
@@ -297,7 +297,7 @@ enum struct AmdBufferCoherenceEnum
     GLC     = 1,
     SLC     = 2,
     GLC_SLC = 3,
-    // gfx942: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
+    // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
     // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
     // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
     WAVE_NT0 = 0,
...
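The comment documents the gfx94 cache-policy bit layout that the enum values encode. Purely for illustration (CK itself uses the fixed enumerators, not a runtime helper), the bits compose like this:

// Illustrative only: build a gfx94 buffer cache-policy value from the
// documented layout (bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1).
constexpr unsigned make_gfx94_cache_policy(bool sc0, bool nt, bool swz, bool sc1)
{
    return (sc0 ? 1u : 0u)       // SC0: low bit of the cache-scope field
         | (nt ? 1u << 1 : 0u)   // NT: no expected temporal reuse
         | (swz ? 1u << 3 : 0u)  // SWZ: swizzled buffer access
         | (sc1 ? 1u << 4 : 0u); // SC1: high bit of the cache-scope field
}

// e.g. device scope (SC[1:0] == 2 -> sc0=0, sc1=1) plus non-temporal:
static_assert(make_gfx94_cache_policy(false, true, false, true) == 0b10010, "");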
include/ck/utility/amd_xdlops.hpp View file @ 036c5234
...
@@ -4,7 +4,7 @@
 #pragma once

 namespace ck {

-// Define the common macro for MI300 models
+// Define the common macro for gfx94x models
 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
 #define __gfx94__
 #endif
...
include/ck/utility/debug.hpp View file @ 036c5234
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #ifndef UTILITY_DEBUG_HPP
 #define UTILITY_DEBUG_HPP
...
@@ -79,6 +79,13 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
     __syncthreads();
 }

+template <index_t... Ids>
+__device__ static bool is_thread_local_1d_id_idx()
+{
+    const auto tid = get_thread_local_1d_id();
+    return ((tid == Ids) || ...);
+}
+
 } // namespace debug
 } // namespace ck
...
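The new helper folds the equality test over the whole Ids pack, so a kernel can gate debug output to a handful of threads. A minimal device-side sketch of how it might be used:

// Only threads 0 and 63 print; any id list can be passed.
if(ck::debug::is_thread_local_1d_id_idx<0, 63>())
{
    printf("lane %d hit the checkpoint\n",
           static_cast<int>(ck::get_thread_local_1d_id()));
}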
include/ck/utility/env.hpp 0 → 100644 View file @ 036c5234
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>

namespace ck {
namespace internal {
template <typename T>
struct ParseEnvVal
{
};

template <>
struct ParseEnvVal<bool>
{
    static bool parse_env_var_value(const char* vp)
    {
        std::string value_env_str{vp};

        for(auto& c : value_env_str)
        {
            if(std::isalpha(c) != 0)
            {
                c = std::tolower(static_cast<unsigned char>(c));
            }
        }

        if(value_env_str == "disable" || value_env_str == "disabled" || value_env_str == "0" ||
           value_env_str == "no" || value_env_str == "off" || value_env_str == "false")
        {
            return false;
        }
        else if(value_env_str == "enable" || value_env_str == "enabled" ||
                value_env_str == "1" || value_env_str == "yes" || value_env_str == "on" ||
                value_env_str == "true")
        {
            return true;
        }
        else
        {
            throw std::runtime_error("Invalid value for env variable");
        }

        return false; // shouldn't reach here
    }
};

// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default).
// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string).
template <>
struct ParseEnvVal<uint64_t>
{
    static uint64_t parse_env_var_value(const char* vp)
    {
        return std::strtoull(vp, nullptr, 0);
    }
};

template <>
struct ParseEnvVal<std::string>
{
    static std::string parse_env_var_value(const char* vp) { return std::string{vp}; }
};

template <typename T>
struct EnvVar
{
    private:
    T value{};
    bool is_unset = true;

    public:
    const T& GetValue() const { return value; }

    bool IsUnset() const { return is_unset; }

    void Unset() { is_unset = true; }

    void UpdateValue(const T& val)
    {
        is_unset = false;
        value    = val;
    }

    explicit EnvVar(const char* const name, const T& def_val)
    {
        // NOLINTNEXTLINE (concurrency-mt-unsafe)
        const char* vp = std::getenv(name);
        if(vp != nullptr) // a value was provided
        {
            is_unset = false;
            value    = ParseEnvVal<T>::parse_env_var_value(vp);
        }
        else // no value provided, use default value
        {
            value = def_val;
        }
    }
};
} // end namespace internal

// static inside function hides the variable and provides
// thread-safety/locking
// Used in global namespace
#define CK_DECLARE_ENV_VAR(name, type, default_val)                            \
    namespace ck::env {                                                        \
    struct name                                                                \
    {                                                                          \
        static_assert(std::is_same_v<name, ::ck::env::name>,                   \
                      "CK_DECLARE_ENV* must be used in the global namespace"); \
        using value_type = type;                                               \
        static ck::internal::EnvVar<type>& Ref()                               \
        {                                                                      \
            static ck::internal::EnvVar<type> var{#name, default_val};         \
            return var;                                                        \
        }                                                                      \
    };                                                                         \
    }

#define CK_DECLARE_ENV_VAR_BOOL(name) CK_DECLARE_ENV_VAR(name, bool, false)

#define CK_DECLARE_ENV_VAR_UINT64(name) CK_DECLARE_ENV_VAR(name, uint64_t, 0)

#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "")

#define ENV(name) \
    ck::env::name {}

template <class EnvVar>
inline const std::string& EnvGetString(EnvVar)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, std::string>);
    return EnvVar::Ref().GetValue();
}

template <class EnvVar>
inline bool EnvIsEnabled(EnvVar)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
    return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue();
}

template <class EnvVar>
inline bool EnvIsDisabled(EnvVar)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
    return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue();
}

template <class EnvVar>
inline uint64_t EnvValue(EnvVar)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, uint64_t>);
    return EnvVar::Ref().GetValue();
}

template <class EnvVar>
inline bool EnvIsUnset(EnvVar)
{
    return EnvVar::Ref().IsUnset();
}

template <class EnvVar>
void EnvUnset(EnvVar)
{
    EnvVar::Ref().Unset();
}

/// updates the cached value of an environment variable
template <typename EnvVar, typename ValueType>
void UpdateEnvVar(EnvVar, const ValueType& val)
{
    static_assert(std::is_same_v<typename EnvVar::value_type, ValueType>);
    EnvVar::Ref().UpdateValue(val);
}

template <typename EnvVar>
void UpdateEnvVar(EnvVar, const std::string_view& val)
{
    EnvVar::Ref().UpdateValue(
        ck::internal::ParseEnvVal<typename EnvVar::value_type>::parse_env_var_value(val.data()));
}

} // namespace ck
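Each declared variable becomes a tag type whose Ref() holds a function-local static EnvVar, read from the process environment once. A short usage sketch (CK_LOGGING is a hypothetical variable name, not one this header declares):

// At global namespace scope:
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)

void example()
{
    // True only if the variable was set to "1"/"on"/"yes"/"true"/"enable(d)".
    if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
    {
        // emit extra diagnostics...
    }

    // Override the cached value at runtime, independent of the process env:
    ck::UpdateEnvVar(ENV(CK_LOGGING), false);
}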
include/ck/utility/flush_icache.hpp 0 → 100644 View file @ 036c5234
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <hip/hip_runtime.h>

namespace ck {
static __global__ void flush_icache()
{
    asm __volatile__("s_icache_inv \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t"
                     "s_nop 0 \n\t" :::);
}
} // namespace ck
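The kernel issues s_icache_inv (padded with s_nop) to invalidate the instruction cache, which is useful for keeping benchmark iterations from measuring warm-I$ behaviour. A host-side sketch of how a harness might call it; the grid and block sizes below are arbitrary, and a real harness would size the grid to cover every CU:

#include <hip/hip_runtime.h>
#include "ck/utility/flush_icache.hpp"

// Launch between timed kernel runs so each run starts with a cold I-cache.
inline hipError_t flush_instruction_cache(hipStream_t stream = nullptr)
{
    ck::flush_icache<<<dim3(64), dim3(64), 0, stream>>>();
    return hipGetLastError();
}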
include/ck/utility/loop_scheduler.hpp View file @ 036c5234
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <ostream>

 #pragma once
...
include/ck/utility/sequence.hpp View file @ 036c5234
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#include <ostream>
+
 #include "ck/utility/integral_constant.hpp"
 #include "ck/utility/type.hpp"
 #include "ck/utility/functional.hpp"
...
@@ -897,3 +899,14 @@ template <index_t NSize, index_t I>
 using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;

 } // namespace ck
+
+template <ck::index_t... Is>
+std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
+{
+    using S = ck::Sequence<Is...>;
+    os << "{";
+    ck::static_for<0, S::Size() - ck::Number<1>{}, 1>{}(
+        [&](auto i) { os << S::At(i).value << ", "; });
+    os << S::At(S::Size() - ck::Number<1>{}).value << "}";
+    return os;
+}
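The streaming operator is declared at global scope, so ordinary unqualified lookup finds it; note it unconditionally prints the last element, so it assumes a non-empty Sequence. A quick host-side usage sketch:

#include <iostream>
#include "ck/utility/sequence.hpp"

int main()
{
    std::cout << ck::Sequence<0, 2, 4>{} << std::endl; // prints "{0, 2, 4}"
    return 0;
}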
include/ck/utility/type.hpp View file @ 036c5234
...
@@ -40,23 +40,10 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value;

 template <typename Y,
           typename X,
           typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
 __host__ __device__ constexpr Y bit_cast(const X& x)
 {
-#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
-    Y y;
-
-    // auto t = reinterpret_cast<const Y*>(&x);
-    // y = *t;
-    __builtin_memcpy(&y, &x, sizeof(X));
-
-    return y;
-#else
-    union AsType
-    {
-        X x;
-        Y y;
-    };
-
-    return AsType{x}.y;
-#endif
+    static_assert(__has_builtin(__builtin_bit_cast), "");
+    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
+
+    return __builtin_bit_cast(Y, x);
 }

 } // namespace ck
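Unlike the old memcpy and union paths, __builtin_bit_cast is usable in constant expressions, so the constexpr qualifier on bit_cast now actually holds. A minimal check:

// __builtin_bit_cast evaluates at compile time, so this static_assert works:
constexpr uint32_t one_bits = ck::bit_cast<uint32_t>(1.0f);
static_assert(one_bits == 0x3f800000u, "IEEE-754 bit pattern of 1.0f");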
include/ck/utility/type_convert.hpp View file @ 036c5234
...
@@ -8,7 +8,7 @@
 #include "ck/utility/random_gen.hpp"

 namespace ck {

-// Define the common macro for MI300 models
+// Define the common macro for gfx94x models
 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
 #define __gfx94__
 #endif
...
include/ck_tile/core/config.hpp View file @ 036c5234
...
@@ -154,3 +154,8 @@
 #ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
 #define CK_TILE_USE_SUBDWORD_TILE_CAST 0
 #endif
+
+// TODO: better solve this inside compiler
+#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
+#define CK_TILE_FMHA_FWD_FAST_EXP2 0
+#endif
include/ck_tile/core/numeric/math.hpp View file @ 036c5234
...
@@ -536,4 +536,15 @@ float log(float x) { return __logf(x); };
 CK_TILE_HOST
 float log(float x) { return std::logf(x); };

+CK_TILE_DEVICE
+uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+{
+    // TODO: this is hacky, we use u16
+    return __builtin_amdgcn_sad_u16(x, y, acc);
+}
+
+CK_TILE_HOST
+uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+{
+    return (x > y ? (x - y) : (y - x)) + acc;
+}
+
 } // namespace ck_tile
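Both overloads compute a fused sum-of-absolute-differences, sad(x, y, acc) = |x - y| + acc. The host overload is the plain reference formula; per the TODO, the device overload lowers to the u16 SAD builtin, so exact agreement with the host path should only be expected for operands within that instruction's 16-bit semantics. A host-side sanity check (illustrative; assumes the ck_tile headers are on the include path):

#include <cassert>
#include "ck_tile/core/numeric/math.hpp"

int main()
{
    assert(ck_tile::sad(7u, 3u, 10u) == 14u); // |7 - 3| + 10
    assert(ck_tile::sad(3u, 7u, 10u) == 14u); // symmetric in x and y
    return 0;
}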