gaoqiong / composable_kernel · Commits · 0a808724

Commit 0a808724, authored Dec 09, 2022 by aska-0096
Tidy up + format
parent 289f15de

Showing 4 changed files, with 634 additions and 368 deletions (+634 −368).
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp   +132 −119
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp     +108 −106
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp              +302 −129
include/ck/utility/amd_wmma.hpp                                  +92  −14
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
(This diff is collapsed.)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

```diff
@@ -38,8 +38,10 @@ __global__ void
         FloatC* __restrict__ p_c_grid,
         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock,
+        // const
+        // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
+        //     c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
```
```diff
@@ -49,18 +51,17 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
```
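Note on the guard in this hunk: the kernel body is compiled only for the host pass (`__HIP_DEVICE_COMPILE__` undefined) and for gfx1100 device passes; every other target compiles an empty stub that consumes its parameters through `ignore` so the build stays warning-free. A minimal sketch of the same pattern, with a hypothetical kernel name (only `ignore` is taken from the surrounding code, where it is assumed to be ck's utility helper):

```cpp
// Sketch of the arch-guard pattern (hypothetical kernel, not from this commit).
template <typename T>
__global__ void demo_kernel(const T* p_in, T* p_out)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
    // Real work exists only on the host pass and on gfx1100 device passes.
    p_out[threadIdx.x] = p_in[threadIdx.x];
#else
    // Other device passes: silence unused-parameter warnings, emit nothing.
    ignore = p_in;
    ignore = p_out;
#endif
}
```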
```diff
@@ -75,50 +76,49 @@ __global__ void
 #endif // end of if (defined(__gfx1100__))
 }

 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatCShuffle,
           typename FloatC,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M_N,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t K0PerBlock,
           index_t MPerWmma,
           index_t NPerWmma,
           index_t K1Value,
           index_t MRepeat,
           index_t NRepeat,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_K1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
           bool ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           index_t BBlockTransferSrcVectorDim,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_K1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
           bool BBlockLdsExtraN,
           index_t CShuffleMRepeatPerShuffle,
           index_t CShuffleNRepeatPerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           index_t NumGemmKPrefetchStage = 1,
           LoopScheduler LoopSched       = make_default_loop_scheduler(),
           PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
 {
     static constexpr auto I0 = Number<0>{};
```
```diff
@@ -202,17 +202,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_desc_k0perblock_mperblock_k1 =
             GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
         constexpr auto b_block_desc_k0perblock_nperblock_k1 =
             GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

         constexpr auto max_lds_align = K1;

         constexpr auto a_block_space_size_aligned =
             math::integer_least_multiple(
                 a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);

         constexpr auto b_block_space_size_aligned =
             math::integer_least_multiple(
                 b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

         return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
     }
```
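The LDS budget above is just "round each tile's element count up to a K1-aligned multiple, sum, multiply by the element size". A standalone sketch of that arithmetic in plain C++, with hypothetical tile sizes and `integer_least_multiple` re-derived rather than taken from `ck::math`:

```cpp
#include <cstddef>
#include <cstdio>

// Assumption: ck's math::integer_least_multiple rounds x up to a multiple of m.
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t m)
{
    return ((x + m - 1) / m) * m;
}

int main()
{
    // Hypothetical tile shape: K0PerBlock=4, MPerBlock=128, NPerBlock=64, K1=8.
    constexpr std::size_t a_elems = 4 * 128 * 8; // A tile element count in LDS
    constexpr std::size_t b_elems = 4 * 64 * 8;  // B tile element count in LDS
    constexpr std::size_t k1      = 8;           // max_lds_align = K1

    constexpr std::size_t a_aligned = integer_least_multiple(a_elems, k1);
    constexpr std::size_t b_aligned = integer_least_multiple(b_elems, k1);

    // sizeof(FloatAB) == 2 for fp16 inputs.
    std::printf("LDS bytes: %zu\n", (a_aligned + b_aligned) * 2);
    return 0;
}
```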
```diff
@@ -308,18 +310,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
-                                                                           FloatAB,
-                                                                           FloatAcc,
-                                                                           decltype(a_block_desc_k0perblock_mperblock_k1),
-                                                                           decltype(b_block_desc_k0perblock_nperblock_k1),
-                                                                           MPerWmma,
-                                                                           NPerWmma,
-                                                                           MRepeat,
-                                                                           NRepeat,
-                                                                           KPack>;
+        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
+            BlockSize,
+            FloatAB,
+            FloatAcc,
+            decltype(a_block_desc_k0perblock_mperblock_k1),
+            decltype(b_block_desc_k0perblock_nperblock_k1),
+            MPerWmma,
+            NPerWmma,
+            MRepeat,
+            NRepeat,
+            KPack>;

         return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
             c_grid_desc_m_n);
     }

     // Per pixel
```
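In this hunk, WmmaK is the K-extent of one WMMA instruction (16 for the 16x16x16 intrinsics used here), and KPack rounds the per-load vector length K1 up to a multiple of it. A tiny worked example, assuming K1 = 8 (a typical fp16 configuration; the value is a template parameter in the real code):

```cpp
// Worked example of the KPack computation (assumption: K1 = 8, fp16 inputs).
constexpr int WmmaK = 16;                                 // K per wmma_*_16x16x16 issue
constexpr int K1    = 8;                                  // elements per vector load
constexpr int KPack = ((K1 + WmmaK - 1) / WmmaK) * WmmaK; // integer_least_multiple(K1, WmmaK)
static_assert(KPack == 16, "two K1-vectors feed one WMMA instruction");
```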
```diff
@@ -362,18 +367,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
-                                                                           FloatAB,
-                                                                           FloatAcc,
-                                                                           decltype(a_block_desc_k0perblock_mperblock_k1),
-                                                                           decltype(b_block_desc_k0perblock_nperblock_k1),
-                                                                           MPerWmma,
-                                                                           NPerWmma,
-                                                                           MRepeat,
-                                                                           NRepeat,
-                                                                           KPack>;
+        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
+            BlockSize,
+            FloatAB,
+            FloatAcc,
+            decltype(a_block_desc_k0perblock_mperblock_k1),
+            decltype(b_block_desc_k0perblock_nperblock_k1),
+            MPerWmma,
+            NPerWmma,
+            MRepeat,
+            NRepeat,
+            KPack>;

         return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
             c_grid_desc_m_n);
     }

     __host__ __device__ static constexpr auto
```
```diff
@@ -402,11 +410,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
             c_grid_desc_m_n);
     }

-    // using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup = remove_cvref_t<decltype(
+    // using
+    // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
+    // = remove_cvref_t<decltype(
     //     MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
     //     CGridDesc_M_N{}))>;

     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
         remove_cvref_t<decltype(
             MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;

     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
```
```diff
@@ -419,15 +429,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-        // const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
+        // const
+        // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
         //     c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
         const CElementwiseOperation& c_element_op,
         const Block2CTileMap& block_2_ctile_map)
     {
         // clang-format off
         /*******************************************************************************/
         // Memory buffer zone.
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
```
```diff
@@ -453,12 +464,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         /*******************************************************************************/
         // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
         const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
-        // printf("K0 = %d, M = %d, K1 = %d\n", K0, a_grid_desc_k0_m_k1.GetLength(I1), (a_grid_desc_k0_m_k1.GetLength(I2))());
         constexpr auto max_lds_align = K1;
         constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
         constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-        // printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(),
-        // (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))());

         // A matrix blockwise copy
         auto a_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
```
```diff
@@ -532,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                                                          FloatAB,
                                                          FloatAcc,
                                                          decltype(a_block_desc_k0perblock_mperblock_k1),
```
```diff
@@ -838,19 +846,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             if constexpr(access_id < num_access - 1)
             {
                 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
-                // CONFIRMED
-                // printf("c_global_step = (%d, %d, %d, %d)\n",
-                //        c_global_step[Number<0>{}],
-                //        c_global_step[Number<1>{}],
-                //        c_global_step[Number<2>{}],
-                //        c_global_step[Number<3>{}]);

                 // move on C
                 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                     c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
             }
         });
     }
     // clang-format on
 }
 };
```
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
(This diff is collapsed.)
include/ck/utility/amd_wmma.hpp

```diff
@@ -8,6 +8,8 @@
 // TODO: Add arch limitation
 namespace ck {

+/********************************WAVE32 MODE***********************************************/
+
 // src: fp16, dst: fp32
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_wmma_f32_16x16x16_f16_w32;
```
```diff
@@ -23,20 +25,6 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
     }
 };

-template <index_t MPerWave, index_t NPerWave>
-struct intrin_wmma_f32_16x16x16_f16_w64;
-
-template <>
-struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
-{
-    template <class FloatC>
-    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
-    {
-        reg_c.template AsType<float4_t>()(Number<0>{}) =
-            __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
-                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
-    }
-};
-
 // src: bf16, dst: fp32
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_wmma_f32_16x16x16_bf16_w32;
```
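The wrapper pattern in this file (a primary template that is only declared, plus a `<16, 16>` specialization) makes any unsupported tile shape a compile-time error. A hypothetical call site, sketched under the assumption that the accumulator object exposes the same `AsType` vector view that `Run()` expects (names other than the wrapper and `half16_t` are invented for illustration):

```cpp
// Sketch only: dispatch one 16x16x16 fp16 WMMA through the wrapper.
// acc_buf is assumed to be a ck StaticBuffer-like object providing the
// AsType<float4_t> view used inside Run(); a_frag/b_frag each hold 16 halves.
template <typename AccBuf>
__device__ void wmma_tile_16x16(const half16_t& a_frag, const half16_t& b_frag, AccBuf& acc_buf)
{
    // A shape with no specialization (e.g. <32, 32>) would fail to compile,
    // because only the declared-but-undefined primary template exists for it.
    intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(a_frag, b_frag, acc_buf);
}
```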
```diff
@@ -111,5 +99,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
     }
 };

+/********************************WAVE64 MODE***********************************************/
+// src: fp16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_f16_w64;
+
+template <>
+struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
+                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+    }
+};
+
+// src: bf16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_bf16_w64;
+
+template <>
+struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
+                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+    }
+};
+
+// src: fp16, dst: fp16
+template <index_t MPerWave, index_t NPerWave, index_t Opsel>
+struct intrin_wmma_f16_16x16x16_f16_w64;
+
+template <index_t Opsel>
+struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
+{
+    template <class FloatC>
+    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
+    {
+        // opsel usage
+        // false: D0.[0:15] = result
+        // true : D0.[16:31]= result
+        reg_c.template AsType<half8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
+                reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
+    }
+};
+
+// src: bf16, dst: bf16
+template <index_t MPerWave, index_t NPerWave, index_t Opsel>
+struct intrin_wmma_bf16_16x16x16_bf16_w64;
+
+template <index_t Opsel>
+struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
+    {
+        // opsel usage
+        // false: D0.[0:15] = result
+        // true : D0.[16:31]= result
+        reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
+                reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
+    }
+};
+
+// src: iu8, dst: i32
+template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w64;
+
+template <bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
+                neg_a,
+                bit_cast<int32x4_t>(reg_a),
+                neg_b,
+                bit_cast<int32x4_t>(reg_b),
+                reg_c.template AsType<int32x4_t>()[Number<0>{}],
+                clamp);
+    }
+};
+
 } // namespace ck
 #endif
```
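On the fp16→fp16 and bf16→bf16 paths above, each 32-bit destination lane carries only one 16-bit result, and `Opsel` selects which half-word receives it, as the inline comments note. A sketch of what that selection means in plain bit terms (hypothetical values; this is not the intrinsic itself, just the D0.[0:15] vs D0.[16:31] rule restated):

```cpp
#include <cstdint>
#include <cstdio>

// Write a 16-bit result into the low (opsel = false) or high (opsel = true)
// half of a 32-bit destination word, leaving the other half untouched.
uint32_t write_halfword(uint32_t d0, uint16_t result, bool opsel)
{
    return opsel ? (d0 & 0x0000FFFFu) | (uint32_t(result) << 16) // D0.[16:31] = result
                 : (d0 & 0xFFFF0000u) | uint32_t(result);        // D0.[0:15]  = result
}

int main()
{
    std::printf("%08x\n", write_halfword(0xAAAABBBBu, 0x1234u, false)); // aaaa1234
    std::printf("%08x\n", write_halfword(0xAAAABBBBu, 0x1234u, true));  // 1234bbbb
    return 0;
}
```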