gaoqiong / composable_kernel · Commits

Commit 87a75734, authored Dec 14, 2020 by Jing Zhang

    adding xdlops

Parent: 7972ab17
Showing 6 changed files with 3483 additions and 0 deletions (+3483 -0)
composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp   +199 -0
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp   +237 -0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp   +1317 -0
composable_kernel/include/tensor_operation/xdlops_gemm.hpp   +1069 -0
composable_kernel/include/utility/amd_xdlops.hpp   +468 -0
driver/include/gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp   +193 -0
composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp (new file, 0 → 100644)
#ifndef CK_GRIDWISE_GROUP_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_GROUP_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "gridwise_gemm_xdlops_fp16_bfp16.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t G,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmKPack,
          class GemmABlockCopyThreadSliceLengths_GemmG_GemmK_GemmM_GemmKPack,
          class GemmABlockCopyThreadClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
          class GemmABlockCopyThreadClusterArrangeOrder,
          class GemmABlockCopySrcAccessOrder,
          class GemmABlockCopyDstAccessOrder,
          index_t GemmABlockCopySrcDataPerRead_GemmKPack,
          index_t GemmABlockCopyDstDataPerWrite_GemmKPack,
          class GemmBBlockCopyThreadSliceLengths_GemmG_GemmK_GemmN_GemmKPack,
          class GemmBBlockCopyThreadClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
          class GemmBBlockCopyThreadClusterArrangeOrder,
          class GemmBBlockCopySrcAccessOrder,
          class GemmBBlockCopyDstAccessOrder,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmKPack,
          WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw
{
    __device__ void Run(const ABFloat* const __restrict__ p_in_global,
                        const ABFloat* const __restrict__ p_wei_global,
                        CFloat* const __restrict__ p_out_global) const
    {
        constexpr auto in_n_c_hi_wi_global_desc        = InGlobalDesc{};
        constexpr auto wei_k_cpergroup_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc       = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_cpergroup_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_cpergroup_y_x_global_desc.GetLengths()[3];

        constexpr index_t CPerGroup = C / G;
        constexpr index_t KPerGroup = K / G;

        static_assert(CPerGroup == wei_k_cpergroup_y_x_global_desc.GetLengths()[1], "wrong!");

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GemmG      = G;
        constexpr index_t GemmM      = KPerGroup;
        constexpr index_t GemmN      = N * Ho * Wo;
        constexpr index_t GemmKTotal = CPerGroup * Y * X;

        static_assert(GemmKTotal % GemmKPack == 0,
                      "wrong! GemmKTotal should be multiple of GemmKPack");

        constexpr index_t GemmK = GemmKTotal / GemmKPack;

        static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
                          GemmK % GemmKPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        // construct tensor descriptor for group convolution
        constexpr auto in_g_n_cpergroup_hi_wi_global_desc = make_native_tensor_descriptor(
            Sequence<G, N, CPerGroup, Hi, Wi>{},
            Sequence<CPerGroup * Hi * Wi, C * Hi * Wi, Hi * Wi, Wi, 1>{});

        constexpr auto wei_g_kpergroup_cpergroup_y_x_global_desc =
            make_native_tensor_descriptor_packed(Sequence<G, KPerGroup, CPerGroup, Y, X>{});

        constexpr auto out_g_n_kpergroup_ho_wo_global_desc = make_native_tensor_descriptor(
            Sequence<G, N, KPerGroup, Ho, Wo>{},
            Sequence<KPerGroup * Ho * Wo, K * Ho * Wo, Ho * Wo, Wo, 1>{});

        // input tensor
        constexpr auto in_g_n_cpergroup_hip_wip_global_desc = transform_tensor_descriptor(
            in_g_n_cpergroup_hi_wi_global_desc,
            make_tuple(PassThrough<G>{},
                       PassThrough<N>{},
                       PassThrough<CPerGroup>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));

        constexpr index_t Hip = in_g_n_cpergroup_hip_wip_global_desc.GetLengths()[3];
        constexpr index_t Wip = in_g_n_cpergroup_hip_wip_global_desc.GetLengths()[4];

        constexpr auto in_g_n_cpergroup_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_g_n_cpergroup_hip_wip_global_desc,
            make_tuple(PassThrough<G>{},
                       PassThrough<N>{},
                       PassThrough<CPerGroup>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));

        constexpr auto in_gemmg_gemmktotal_gemmn_global_desc = transform_tensor_descriptor(
            in_g_n_cpergroup_y_ho_x_wo_global_desc,
            make_tuple(PassThrough<G>{}, Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<0>{}, Sequence<2, 3, 5>{}, Sequence<1, 4, 6>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        constexpr auto in_gemmg_gemmk_gemmn_gemmkpack_global_desc = transform_tensor_descriptor(
            in_gemmg_gemmktotal_gemmn_global_desc,
            make_tuple(
                PassThrough<GemmG>{}, UnMerge<Sequence<GemmK, GemmKPack>>{}, PassThrough<GemmN>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));

        // weight tensor
        constexpr auto wei_gemmg_gemmm_gemmktotal_global_desc = unfold_tensor_descriptor(
            wei_g_kpergroup_cpergroup_y_x_global_desc, Number<2>{}, Number<4>{});

        constexpr auto wei_gemmg_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmg_gemmm_gemmktotal_global_desc,
            make_tuple(
                PassThrough<GemmG>{}, PassThrough<GemmM>{}, UnMerge<Sequence<GemmK, GemmKPack>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3>{}));

        // output tensor
        constexpr auto out_gemmg_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            out_g_n_kpergroup_ho_wo_global_desc,
            make_tuple(PassThrough<G>{}, PassThrough<KPerGroup>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // gridwise batch-GEMM
        constexpr auto gridwise_gemm = GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2<
            GridSize,
            BlockSize,
            ABFloat,
            AccFloat,
            CFloat,
            decltype(wei_gemmg_gemmk_gemmm_gemmkpack_global_desc),
            decltype(in_gemmg_gemmk_gemmn_gemmkpack_global_desc),
            decltype(out_gemmg_gemmm_gemmn_global_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmABlockCopyThreadSliceLengths_GemmG_GemmK_GemmM_GemmKPack,
            GemmABlockCopyThreadClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
            GemmABlockCopyThreadClusterArrangeOrder,
            GemmABlockCopySrcAccessOrder,
            GemmABlockCopyDstAccessOrder,
            3, // src vector read dimension of A matrix is GemmKPack
            GemmABlockCopySrcDataPerRead_GemmKPack,
            GemmABlockCopyDstDataPerWrite_GemmKPack,
            GemmBBlockCopyThreadSliceLengths_GemmG_GemmK_GemmN_GemmKPack,
            GemmBBlockCopyThreadClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
            GemmBBlockCopyThreadClusterArrangeOrder,
            GemmBBlockCopySrcAccessOrder,
            GemmBBlockCopyDstAccessOrder,
            2, // src vector read dimension of B matrix is GemmN
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmKPack,
            InMemoryDataOperation::Set,
            WorkgroupSchdOrder>{};

        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
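A minimal host-side sketch (not part of the commit) of the dimension mapping that the Run() above performs when it lowers a forward convolution to an implicit batched GEMM. All problem sizes below are hypothetical example values; only the arithmetic mirrors the kernel.

#include <cstdio>

int main()
{
    // hypothetical NCHW / KCYX convolution problem
    constexpr int N = 128, C = 64, Hi = 28, Wi = 28;
    constexpr int K = 128, Y = 3, X = 3, G = 1;
    constexpr int Ho = 28, Wo = 28;  // assumes stride 1, pad 1, dilation 1
    constexpr int GemmKPack = 4;     // e.g. 4 packed fp16 values

    constexpr int CPerGroup = C / G;
    constexpr int KPerGroup = K / G;

    // GEMM view used by the xdlops kernel above
    constexpr int GemmM      = KPerGroup;             // output channels per group
    constexpr int GemmN      = N * Ho * Wo;           // output pixels
    constexpr int GemmKTotal = CPerGroup * Y * X;     // reduction length
    constexpr int GemmK      = GemmKTotal / GemmKPack;

    static_assert(GemmKTotal % GemmKPack == 0, "GemmKTotal should be multiple of GemmKPack");

    std::printf("GemmM=%d GemmN=%d GemmKTotal=%d GemmK=%d\n", GemmM, GemmN, GemmKTotal, GemmK);
    return 0;
}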
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp (new file, 0 → 100644)
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"

namespace ck {

template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class Float,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmMWaves,
          index_t GemmNWaves,
          index_t GemmDataPerReadA, // \todo unused parameter, remove
          index_t GemmDataPerReadB  // \todo unused parameter, remove
          >
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

#if CK_WORKAROUND_SWDEV_241664
    static constexpr index_t MRepeats = (GemmMPerWave > 64) ? (GemmMPerWave / 64) : 1;
    static constexpr index_t NRepeats = (GemmNPerWave > 64) ? (GemmNPerWave / 64) : 1;

    static constexpr index_t MPerXdlops = (GemmMPerWave > 64) ? 64 : GemmMPerWave;
    static constexpr index_t NPerXdlops = (GemmNPerWave > 64) ? 64 : GemmNPerWave;

    static constexpr auto XdlopsGemm =
        XdlopsGemm_t<Float, MPerXdlops, NPerXdlops, GemmDataPerReadA, GemmDataPerReadB>{};
#else
#if CK_USE_AMD_XDLOPS_INLINE_ASM
    /// \todo add inline support for vector type c
    static_assert(false, "Does not support inline asm for vector type c")
#else
    static constexpr auto XdlopsGemm =
        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
#endif
#endif

    index_t mMyWaveOffsetA;
    index_t mMyWaveOffsetB;

    static constexpr index_t WaveSize = 64;

    __device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }

#if CK_WORKAROUND_SWDEV_241664
    template <index_t MRepeats_ = MRepeats, index_t NRepeats_ = NRepeats>
    __device__ constexpr auto CreateOutputVecZero() const;

    template <>
    __device__ constexpr auto CreateOutputVecZero<2, 1>() const
    {
        return c_vec32_2_2_t::CreateVecZero();
    }

    template <>
    __device__ constexpr auto CreateOutputVecZero<1, 2>() const
    {
        return c_vec32_2_2_t::CreateVecZero();
    }

    template <>
    __device__ constexpr auto CreateOutputVecZero<1, 1>() const
    {
        return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
    }
#else
    __device__ constexpr auto CreateOutputVecZero() const
    {
        return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
    }
#endif

    __device__ constexpr auto GetNumBlks() const
    {
#if CK_WORKAROUND_SWDEV_241664
        return XdlopsGemm.GetOutputLayout().GetNumBlks() * MRepeats * NRepeats;
#else
        return XdlopsGemm.GetOutputLayout().GetNumBlks();
#endif
    }

    __device__ constexpr auto GetBlkSize() const
    {
        return XdlopsGemm.GetOutputLayout().GetBlkSize();
    }

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
    {
        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
        static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");

        static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
                      "BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");

        const index_t waveId   = get_thread_local_1d_id() / WaveSize;
        const index_t waveId_m = waveId / GemmNWaves;
        const index_t waveId_n = waveId % GemmNWaves;

        mMyWaveOffsetA = waveId_m * GemmMPerWave;
        mMyWaveOffsetB = waveId_n * GemmNPerWave;
    }

#if CK_WORKAROUND_SWDEV_241664
    template <index_t MRepeats_, index_t NRepeats_>
    struct WithMNRepeats;

    template <>
    struct WithMNRepeats<2, 1>
    {
        template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
        __device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
                                     const FloatB* __restrict__ p_b_block,
                                     FloatC p_c_thread)
        {
            p_c_thread.s.x.l =
                XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
            p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(
                p_a_block + MPerXdlops, p_b_block, p_c_thread.s.y.l);

            return p_c_thread;
        }
    };

    template <>
    struct WithMNRepeats<1, 2>
    {
        template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
        __device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
                                     const FloatB* __restrict__ p_b_block,
                                     FloatC p_c_thread)
        {
            p_c_thread.s.x.l =
                XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
            p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(
                p_a_block, p_b_block + NPerXdlops, p_c_thread.s.y.l);

            return p_c_thread;
        }
    };

    template <>
    struct WithMNRepeats<1, 1>
    {
        template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
        __device__ static FloatC Run(const FloatA* __restrict__ p_a_block,
                                     const FloatB* __restrict__ p_b_block,
                                     FloatC p_c_thread)
        {
            return XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread);
        }
    };
#endif

    template <class FloatA, class FloatB, class FloatC>
    __device__ FloatC Run(const FloatA* __restrict__ p_a_block,
                          const FloatB* __restrict__ p_b_block,
                          FloatC p_c_thread) const
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();
        constexpr index_t K = BlockMatrixA::NRow();

#if CK_WORKAROUND_SWDEV_241664
        return WithMNRepeats<MRepeats, NRepeats>::template Run<M, N, K>(
            &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
#else
        return XdlopsGemm.template Run<M, N, K>(
            &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
#endif
    }

    template <index_t AStride = GemmMPerWave, index_t BStride = GemmNPerWave>
    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
    {
        const index_t waveId = get_thread_local_1d_id() / WaveSize;

#if CK_WORKAROUND_SWDEV_241664
        const index_t xdlops_i = i / XdlopsGemm.GetOutputLayout().GetNumBlks();
        const index_t j        = i % XdlopsGemm.GetOutputLayout().GetNumBlks();

        const index_t m = xdlops_i / NRepeats;
        const index_t n = xdlops_i % NRepeats;

        const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(j);

        const index_t col =
            (waveId % GemmNWaves) * BStride + n * NPerXdlops + thread_mtx_on_blk.col;
        const index_t row =
            (waveId / GemmNWaves) * AStride + m * MPerXdlops + thread_mtx_on_blk.row;
#else
        const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(i);

        const index_t col = (waveId % GemmNWaves) * BStride + thread_mtx_on_blk.col;
        const index_t row = (waveId / GemmNWaves) * AStride + thread_mtx_on_blk.row;
#endif

        return MatrixIndex{row, col};
    }

    __device__ constexpr auto GetThreadMatrixCDescriptor() const
    {
        const index_t total_reg_size = GemmMPerWave * GemmNPerWave / WaveSize;
        return make_ConstantMatrixDescriptor_packed(Number<total_reg_size>{}, Number<1>{});
    }

    __device__ void XdlopsMatrixCSetZero() const { XdlopsGemm.SetZeroXdlopsRegs(); }

    template <class FloatC>
    __device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
    {
        XdlopsGemm.ReadXdlopsRegs(p_c_thread);
    }
};

} // namespace ck
#endif
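A minimal host-side sketch (not part of the commit) of how the blockwise GEMM above splits a thread block into 64-lane waves and assigns each wave its A/B offsets (the mMyWaveOffsetA / mMyWaveOffsetB arithmetic in the constructor). The tile sizes and wave counts below are hypothetical.

#include <cstdio>

int main()
{
    constexpr int WaveSize     = 64;
    constexpr int GemmMPerWave = 64, GemmNPerWave = 64;
    constexpr int GemmMWaves   = 2,  GemmNWaves   = 2; // 4 waves -> BlockSize 256

    constexpr int BlockSize = GemmMWaves * GemmNWaves * WaveSize;

    for(int tid = 0; tid < BlockSize; tid += WaveSize)
    {
        const int waveId   = tid / WaveSize;
        const int waveId_m = waveId / GemmNWaves; // row of the wave grid
        const int waveId_n = waveId % GemmNWaves; // column of the wave grid

        // same arithmetic as mMyWaveOffsetA / mMyWaveOffsetB above
        const int offsetA = waveId_m * GemmMPerWave;
        const int offsetB = waveId_n * GemmNPerWave;

        std::printf("wave %d -> A offset %d, B offset %d\n", waveId, offsetA, offsetB);
    }
    return 0;
}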
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp (new file, 0 → 100644)
#ifndef CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_FP16_BFP16_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_generic_tensor_slice_copy_v2.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"

namespace ck {

enum WorkgroupScheduleOrder
{
    MBlock1NBlock0,
    NBlock1MBlock0
};

template <index_t Gi,
          index_t MBlockWork,
          index_t NBlockWork,
          WorkgroupScheduleOrder WorkgroupSchdOrder>
struct make_batch_block_work_sequence;

template <index_t Gi, index_t MBlockWork, index_t NBlockWork>
struct make_batch_block_work_sequence<Gi, MBlockWork, NBlockWork, MBlock1NBlock0>
{
    __device__ constexpr auto get() { return Sequence<Gi, MBlockWork, NBlockWork>{}; }
};

template <index_t Gi, index_t MBlockWork, index_t NBlockWork>
struct make_batch_block_work_sequence<Gi, MBlockWork, NBlockWork, NBlock1MBlock0>
{
    __device__ constexpr auto get() { return Sequence<Gi, NBlockWork, MBlockWork>{}; }
};

template <index_t MBlockWork, index_t NBlockWork, WorkgroupScheduleOrder WorkgroupSchdOrder>
struct make_block_work_sequence;

template <index_t MBlockWork, index_t NBlockWork>
struct make_block_work_sequence<MBlockWork, NBlockWork, MBlock1NBlock0>
{
    __device__ constexpr auto get() { return Sequence<MBlockWork, NBlockWork>{}; }
};

template <index_t MBlockWork, index_t NBlockWork>
struct make_block_work_sequence<MBlockWork, NBlockWork, NBlock1MBlock0>
{
    __device__ constexpr auto get() { return Sequence<NBlockWork, MBlockWork>{}; }
};

template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class ABlockCopyThreadSliceLengths_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_K_N_KPACK,
          class BBlockCopyThreadClusterLengths_K_N_KPACK,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation OutputMemOp,
          WorkgroupScheduleOrder WorkgroupSchdOrder,
          index_t ABlockCopySrcDataStride = 1,
          index_t BBlockCopySrcDataStride = 1>
struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto b_k_n_kpack_global_desc = BGlobalDesc{};
        constexpr auto a_k_m_kpack_global_desc = AGlobalDesc{};
        constexpr auto c_m_n_global_desc       = CGlobalDesc{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto K     = b_k_n_kpack_global_desc.GetLengths()[0];
        constexpr auto N     = b_k_n_kpack_global_desc.GetLengths()[1];
        constexpr auto M     = a_k_m_kpack_global_desc.GetLengths()[1];
        constexpr auto KPACK = b_k_n_kpack_global_desc.GetLengths()[2];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr index_t MWaves = MPerBlock / MPerWave;
        constexpr index_t NWaves = NPerBlock / NPerWave;

        constexpr auto block_work_sequence =
            make_block_work_sequence<MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
        constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t k_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[0] * MPerBlock)
                                                   : (block_work_id[1] * MPerBlock);
        const index_t b_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[1] * NPerBlock)
                                                   : (block_work_id[0] * NPerBlock);

        // LDS mem
        constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
                                                ABlockCopyDstDataPerWrite_KPACK,
                                                KPACK * GemmDataPerReadM,
                                                KPACK * GemmDataPerReadN);

        // LDS
        //   be careful of LDS alignment
        constexpr auto a_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});

        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(a_k_m_kpack_global_desc),
            decltype(a_k_m_kpack_block_desc),
            decltype(a_k_m_kpack_block_desc.GetLengths()),
            ABlockCopyThreadSliceLengths_K_M_KPACK,
            ABlockCopyThreadClusterLengths_K_M_KPACK,
            ABlockCopyThreadClusterArrangeOrder,
            ABlockCopySrcAccessOrder,
            ABlockCopyDstAccessOrder,
            ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (M dimension)
            2,                          // Dst dim to be written in vector form (KPACK dimension)
            ABlockCopySrcDataPerRead,
            ABlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set,
            ABlockCopySrcDataStride>({0, k_block_data_on_global, 0}, {0, 0, 0});

        constexpr auto b_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(b_k_n_kpack_global_desc),
            decltype(b_k_n_kpack_block_desc),
            decltype(b_k_n_kpack_block_desc.GetLengths()),
            BBlockCopyThreadSliceLengths_K_N_KPACK,
            BBlockCopyThreadClusterLengths_K_N_KPACK,
            BBlockCopyThreadClusterArrangeOrder,
            BBlockCopySrcAccessOrder,
            BBlockCopyDstAccessOrder,
            BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (N dimension)
            2,                          // Dst dim to be written in vector form (KPACK dimension)
            BBlockCopySrcDataPerRead,
            BBlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set,
            BBlockCopySrcDataStride>({0, b_block_data_on_global, 0}, {0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            ABFloat,
            MPerWave,
            NPerWave,
            MWaves,
            NWaves,
            GemmDataPerReadM,
            GemmDataPerReadN>{};

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k_m_kpack_block_desc.GetElementSpace(), max_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k_n_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block_double[2 * a_block_space];
        __shared__ ABFloat p_b_block_double[2 * b_block_space];

        // get zero-initialized output register of vector type
        auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        using blockwise_a_copy_src_step = Sequence<KPerBlock, 0, 0>;
        using blockwise_b_copy_src_step = Sequence<KPerBlock, 0, 0>;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                ABFloat* p_a_block_now =
                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                ABFloat* p_b_block_now =
                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                ABFloat* p_a_block_next =
                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                ABFloat* p_b_block_next =
                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                // Vectorize the pointer to match with how fp16/bfloat16 datatypes are
                // processed in gemm operation. fp16 type packs 4 fp16 values while
                // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
                // 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
                // we recast datatype from a single fp16 to 4 packed fp16/2 packed bfloat16
                // respectively.
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_now);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_now);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double + a_block_space);
                p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double + b_block_space);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
            }
        }

        // copy output: register to global memory
        {
            constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
            constexpr index_t K0        = OutputLayout.M1();
            constexpr index_t K1        = OutputLayout.N1();
            constexpr index_t K2        = OutputLayout.M0();

            constexpr auto out_k0_k1_k2_b_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<K0, K1, K2>>{}, PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));

            // src descriptor
            constexpr auto out_k0_k1_k2_b_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<K0, 1, K2, 1>{});

            using OutThreadCopySliceLengths = Sequence<K0, 1, K2, 1>;

            constexpr index_t BlkSize = OutputLayout.GetBlkSize();
            constexpr index_t NumBlks = OutputLayout.GetNumBlks();

            // force unrolling the output loop to get rid of scratches
#pragma unroll
            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t k_thread_data_on_global =
                    k_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t b_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<
                    decltype(out_k0_k1_k2_b_thread_desc),
                    decltype(out_k0_k1_k2_b_global_desc),
                    OutThreadCopySliceLengths,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    3,
                    1,
                    1,
                    AddressSpace::Vgpr,
                    is_same<AccFloat, CFloat>::value ? AddressSpace::Global
                                                     : AddressSpace::Generic,
                    OutputMemOp>({0, 0, 0, 0},
                                 {k_thread_data_on_global / (K2 * K1),
                                  k_thread_data_on_global % (K2 * K1) / K2,
                                  k_thread_data_on_global % K2,
                                  b_thread_data_on_global})
                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
            }
        }
    }
};
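A minimal sketch (not part of the commit) of the ping-pong buffer selection that the "LDS double buffer: main body" loop above performs; the prefetch, LDS buffers, and GEMM are stubbed out and the sizes below are hypothetical.

#include <cstdio>

int main()
{
    constexpr int K = 8, KPerBlock = 1; // hypothetical GEMM-K extent and per-iteration step

    // same control flow as the main body above: each outer step consumes 2 * KPerBlock,
    // alternating which half of the double buffer is computed on vs. prefetched into
    for(int k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
        k_block_data_begin += 2 * KPerBlock)
    {
        for(int iloop = 0; iloop < 2; ++iloop)
        {
            const bool even_loop = (iloop % 2 == 0);
            const int  now       = even_loop ? 0 : 1; // buffer half the GEMM reads from
            const int  next      = even_loop ? 1 : 0; // buffer half the prefetch writes to

            std::printf("k=%d: compute on half %d, prefetch into half %d\n",
                        k_block_data_begin + iloop * KPerBlock, now, next);
        }
    }
    // tail: one or two iterations remain depending on K % (2 * KPerBlock)
    return 0;
}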
template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_G_K_N_KPACK,
          class BBlockCopyThreadClusterLengths_G_K_N_KPACK,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation OutputMemOp,
          WorkgroupScheduleOrder WorkgroupSchdOrder,
          index_t ABlockCopySrcDataStride = 1,
          index_t BBlockCopySrcDataStride = 1>
struct GridwiseBatchedGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
        constexpr auto b_g_k_n_kpack_global_desc = BGlobalDesc{};
        constexpr auto c_g_m_n_global_desc       = CGlobalDesc{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto Gi = b_g_k_n_kpack_global_desc.GetLengths()[0];
        constexpr auto Go = c_g_m_n_global_desc.GetLengths()[0];

        constexpr auto K     = b_g_k_n_kpack_global_desc.GetLengths()[1];
        constexpr auto N     = b_g_k_n_kpack_global_desc.GetLengths()[2];
        constexpr auto M     = a_g_k_m_kpack_global_desc.GetLengths()[2];
        constexpr auto KPACK = b_g_k_n_kpack_global_desc.GetLengths()[3];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr index_t MWaves = MPerBlock / MPerWave;
        constexpr index_t NWaves = NPerBlock / NPerWave;

        constexpr auto block_work_sequence =
            make_batch_block_work_sequence<Gi, MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
        constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t group_id = block_work_id[0];

        const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[1] * MPerBlock)
                                                   : (block_work_id[2] * MPerBlock);
        const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[2] * NPerBlock)
                                                   : (block_work_id[1] * NPerBlock);

        // LDS mem
        constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
                                                ABlockCopyDstDataPerWrite_KPACK,
                                                KPACK * GemmDataPerReadM,
                                                KPACK * GemmDataPerReadN);

        // LDS
        //   be careful of LDS alignment
        constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});

        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(a_g_k_m_kpack_global_desc),
            decltype(a_g_k_m_kpack_block_desc),
            decltype(a_g_k_m_kpack_block_desc.GetLengths()),
            ABlockCopyThreadSliceLengths_G_K_M_KPACK,
            ABlockCopyThreadClusterLengths_G_K_M_KPACK,
            ABlockCopyThreadClusterArrangeOrder,
            ABlockCopySrcAccessOrder,
            ABlockCopyDstAccessOrder,
            ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (K dimension)
            3,                          // Dst dim to be written in vector form (KPACK dimension)
            ABlockCopySrcDataPerRead,
            ABlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set,
            ABlockCopySrcDataStride>({group_id, 0, m_block_data_on_global, 0}, {0, 0, 0, 0});

        constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(b_g_k_n_kpack_global_desc),
            decltype(b_g_k_n_kpack_block_desc),
            decltype(b_g_k_n_kpack_block_desc.GetLengths()),
            BBlockCopyThreadSliceLengths_G_K_N_KPACK,
            BBlockCopyThreadClusterLengths_G_K_N_KPACK,
            BBlockCopyThreadClusterArrangeOrder,
            BBlockCopySrcAccessOrder,
            BBlockCopyDstAccessOrder,
            BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (K dimension)
            3,                          // Dst dim to be written in vector form (KPACK dimension)
            BBlockCopySrcDataPerRead,   // N dimension
            BBlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set,
            BBlockCopySrcDataStride>({group_id, 0, n_block_data_on_global, 0}, {0, 0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            ABFloat,
            MPerWave,
            NPerWave,
            MWaves,
            NWaves,
            GemmDataPerReadM,
            GemmDataPerReadN>{};

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_g_k_n_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block_double[2 * a_block_space];
        __shared__ ABFloat p_b_block_double[2 * b_block_space];

        // get zero-initialized output register of vector type
        auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        using blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>;
        using blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0>;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                ABFloat* p_a_block_now =
                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                ABFloat* p_b_block_now =
                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                ABFloat* p_a_block_next =
                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                ABFloat* p_b_block_next =
                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                // Vectorize the pointer to match with how fp16/bfloat16 datatypes are
                // processed in gemm operation. fp16 type packs 4 fp16 values while
                // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
                // 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
                // we recast datatype from a single fp16 to 4 packed fp16/2 packed bfloat16
                // respectively.
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_now);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_now);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double + a_block_space);
                p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double + b_block_space);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
            }
        }

        // copy output: register to global memory
        {
            ///\todo inconsistent layout of xdlops and tensor
            // xdlops layout
            //   M1 = num_groups;
            //   M0 = group_size;
            //   N1 = num_blks_per_wave;
            //   N0 = num_threads_per_blks;
            constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
            constexpr index_t M0   = CLayout.M1();
            constexpr index_t M1   = CLayout.N1();
            constexpr index_t M2   = CLayout.M0();

            constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
                c_g_m_n_global_desc,
                make_tuple(
                    PassThrough<Go>{}, UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));

            // src descriptor
            constexpr auto c_g_m0_m1_m2_n_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});

            using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;

            constexpr index_t BlkSize = CLayout.GetBlkSize();
            constexpr index_t NumBlks = CLayout.GetNumBlks();

            // force unrolling the output loop to get rid of scratches
#pragma unroll
            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t m_thread_data_on_global =
                    m_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t n_thread_data_on_global =
                    n_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
                                                      decltype(c_g_m0_m1_m2_n_global_desc),
                                                      CThreadCopySliceLengths,
                                                      arithmetic_sequence_gen<0, 5, 1>::type,
                                                      4,
                                                      1,
                                                      1,
                                                      AddressSpace::Vgpr,
                                                      AddressSpace::Global,
                                                      OutputMemOp>(
                    {0, 0, 0, 0, 0},
                    {group_id,
                     m_thread_data_on_global / (M2 * M1),
                     m_thread_data_on_global % (M2 * M1) / M2,
                     m_thread_data_on_global % M2,
                     n_thread_data_on_global})
                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
            }
        }
    }
};
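A minimal host-side sketch (not part of the commit) of the M -> (M0, M1, M2) index split used above when each thread's C results are written back to global memory. The layout factors M1 and M2 and the row index below are hypothetical example values.

#include <cstdio>

int main()
{
    constexpr int M1 = 4, M2 = 4;            // hypothetical xdlops output layout factors
    const int m_thread_data_on_global = 37;  // hypothetical row index within the GEMM

    // same arithmetic as the ThreadwiseGenericTensorSliceCopy_v4r2 origin above
    const int m0 = m_thread_data_on_global / (M2 * M1);
    const int m1 = m_thread_data_on_global % (M2 * M1) / M2;
    const int m2 = m_thread_data_on_global % M2;

    // (m0 * M1 + m1) * M2 + m2 recovers the original row index
    std::printf("m=%d -> (m0=%d, m1=%d, m2=%d), check=%d\n",
                m_thread_data_on_global, m0, m1, m2, (m0 * M1 + m1) * M2 + m2);
    return 0;
}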
template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_G_K_N_KPACK,
          class BBlockCopyThreadClusterLengths_G_K_N_KPACK,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation CGlobalMemoryOp,
          WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
        constexpr auto b_g_k_n_kpack_global_desc = BGlobalDesc{};
        constexpr auto c_g_m_n_global_desc       = CGlobalDesc{};

        constexpr auto G = c_g_m_n_global_desc.GetLengths()[0];
        constexpr auto M = c_g_m_n_global_desc.GetLengths()[1];
        constexpr auto N = c_g_m_n_global_desc.GetLengths()[2];

        constexpr auto K     = b_g_k_n_kpack_global_desc.GetLengths()[1];
        constexpr auto KPack = b_g_k_n_kpack_global_desc.GetLengths()[3];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr index_t MWavePerBlock = MPerBlock / MPerWave;
        constexpr index_t NWavePerBlock = NPerBlock / NPerWave;

        constexpr auto block_work_sequence =
            make_batch_block_work_sequence<G, MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
        constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t g_block_data_on_global = block_work_id[0];

        const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[1] * MPerBlock)
                                                   : (block_work_id[2] * MPerBlock);
        const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[2] * NPerBlock)
                                                   : (block_work_id[1] * NPerBlock);

        constexpr index_t max_align = KPack;

        // LDS
        //   be careful of LDS alignment
        constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});

        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(a_g_k_m_kpack_global_desc),
            decltype(a_g_k_m_kpack_block_desc),
            decltype(a_g_k_m_kpack_block_desc.GetLengths()),
            ABlockCopyThreadSliceLengths_G_K_M_KPACK,
            ABlockCopyThreadClusterLengths_G_K_M_KPACK,
            ABlockCopyThreadClusterArrangeOrder,
            ABlockCopySrcAccessOrder,
            ABlockCopyDstAccessOrder,
            ABlockCopySrcVectorReadDim, // Src dim to be read in vector form
            3,                          // Dst dim to be written in vector form (KPack dimension)
            ABlockCopySrcDataPerRead,
            ABlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>({g_block_data_on_global, 0, m_block_data_on_global, 0},
                                        {0, 0, 0, 0});

        constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, NPerBlock, KPack>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v5<
            BlockSize,
            decltype(b_g_k_n_kpack_global_desc),
            decltype(b_g_k_n_kpack_block_desc),
            decltype(b_g_k_n_kpack_block_desc.GetLengths()),
            BBlockCopyThreadSliceLengths_G_K_N_KPACK,
            BBlockCopyThreadClusterLengths_G_K_N_KPACK,
            BBlockCopyThreadClusterArrangeOrder,
            BBlockCopySrcAccessOrder,
            BBlockCopyDstAccessOrder,
            BBlockCopySrcVectorReadDim, // Src dim to be read in vector form
            3,                          // Dst dim to be written in vector form (KPack dimension)
            BBlockCopySrcDataPerRead,
            BBlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>({g_block_data_on_global, 0, n_block_data_on_global, 0},
                                        {0, 0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            ABFloat,
            MPerWave,
            NPerWave,
            MWavePerBlock,
            NWavePerBlock,
            1,
            1>{};

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_g_k_n_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block[a_block_space];
        __shared__ ABFloat p_b_block[b_block_space];

        // get zero-initialized output register of vector type
        auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();

        // preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block);
            b_blockwise_copy.Run(p_b_global, p_b_block);
        }

        constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
        constexpr auto blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};

        // main body
        for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
            k_block_data_begin += KPerBlock)
        {
            ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
            // ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

            // load next data from device mem
            a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);

            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
            b_blockwise_copy.RunLoadThreadBuffer(p_b_global);

            block_sync_lds();

            // GEMM on current data
            const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
                    p_a_block);
            const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
                    p_b_block);
            c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

            block_sync_lds();

            // store next data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
        }

        // tail
        {
            block_sync_lds();

            // GEMM on last data
            const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
                    p_a_block);
            const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
                    p_b_block);
            c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
        }

        // copy output: register to global memory
        {
            ///\todo inconsistent layout of xdlops and tensor
            // xdlops layout
            //   M1 = num_groups;
            //   M0 = group_size;
            //   N1 = num_blks_per_wave;
            //   N0 = num_threads_per_blks;
            constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
            constexpr index_t M0   = CLayout.M1();
            constexpr index_t M1   = CLayout.N1();
            constexpr index_t M2   = CLayout.M0();

            constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
                c_g_m_n_global_desc,
                make_tuple(PassThrough<G>{},
                           UnMerge<Sequence<M / (M1 * M2), M1, M2>>{},
                           PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));

            // src descriptor
            constexpr auto c_g_m0_m1_m2_n_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});

            using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;

            constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
            constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();

            // force unrolling the output loop to get rid of scratches
#pragma unroll
            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t m_thread_data_on_global =
                    m_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t n_thread_data_on_global =
                    n_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
                                                      decltype(c_g_m0_m1_m2_n_global_desc),
                                                      CThreadCopySliceLengths,
                                                      arithmetic_sequence_gen<0, 5, 1>::type,
                                                      4,
                                                      1,
                                                      1,
                                                      AddressSpace::Vgpr,
                                                      AddressSpace::Global,
                                                      CGlobalMemoryOp>(
                    {0, 0, 0, 0, 0},
                    {g_block_data_on_global,
                     m_thread_data_on_global / (M2 * M1),
                     m_thread_data_on_global % (M2 * M1) / M2,
                     m_thread_data_on_global % M2,
                     n_thread_data_on_global})
                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
            }
        }
    }
};
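A minimal host-side sketch (not part of the commit) of the "recast scalars to a packed vector" step the kernels above perform before calling the blockwise GEMM: KPack scalar values per GEMM-K step are reinterpreted as one packed element. The half4 struct below is a hypothetical stand-in for vector_type<ABFloat, KPACK>::MemoryType.

#include <cstdio>

struct half4 // stand-in for 4 packed fp16 values
{
    unsigned short x, y, z, w;
};

int main()
{
    constexpr int KPack = 4;
    unsigned short scalars[2 * KPack] = {1, 2, 3, 4, 5, 6, 7, 8}; // two packed groups

    // mirrors the kernels' reinterpret_cast of an LDS scalar pointer to a packed-vector pointer
    const half4* packed = reinterpret_cast<const half4*>(scalars);

    std::printf("packed[1] = {%u, %u, %u, %u}\n",
                packed[1].x, packed[1].y, packed[1].z, packed[1].w);
    return 0;
}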
template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t BPerWave,
          class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_G_K_N1_B_KPack,
          class BBlockCopyThreadClusterLengths_G_K_N1_B_KPack,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation CGlobalMemoryOp,
          WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseBatchGemmXdlops_gkmkpack_gkn1bkpack_gmn_v2
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto a_g_k_m_kpack_global_desc    = AGlobalDesc{};
        constexpr auto b_g_k_n1_b_kpack_global_desc = BGlobalDesc{};
        constexpr auto c_g_m_n_global_desc          = CGlobalDesc{};

        constexpr auto G = c_g_m_n_global_desc.GetLengths()[0];
        constexpr auto M = c_g_m_n_global_desc.GetLengths()[1];
        constexpr auto N = c_g_m_n_global_desc.GetLengths()[2];

        constexpr auto K     = b_g_k_n1_b_kpack_global_desc.GetLengths()[1];
        constexpr auto in_N1 = b_g_k_n1_b_kpack_global_desc.GetLengths()[2];
        constexpr auto B     = b_g_k_n1_b_kpack_global_desc.GetLengths()[3];
        constexpr auto KPack = b_g_k_n1_b_kpack_global_desc.GetLengths()[4];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr index_t MWavePerBlock = MPerBlock / MPerWave;
        constexpr index_t BWavePerBlock = in_N1;

        static_assert((G * MBlockWork * BBlockWork) == GridSize, "Invalid GridSize");

        constexpr auto block_work_sequence =
            make_batch_block_work_sequence<G, MBlockWork, BBlockWork, WorkgroupSchdOrder>{}.get();
        constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t g_block_data_on_global = block_work_id[0];
        const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[1] * MPerBlock)
                                                   : (block_work_id[2] * MPerBlock);
        const index_t b_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[2] * BPerBlock)
                                                   : (block_work_id[1] * BPerBlock);

        constexpr index_t max_align = KPack;

        // LDS be careful of LDS alignment
        constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});

        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(a_g_k_m_kpack_global_desc),
            decltype(a_g_k_m_kpack_block_desc),
            decltype(a_g_k_m_kpack_block_desc.GetLengths()),
            ABlockCopyThreadSliceLengths_G_K_M_KPACK,
            ABlockCopyThreadClusterLengths_G_K_M_KPACK,
            ABlockCopyThreadClusterArrangeOrder,
            ABlockCopySrcAccessOrder,
            ABlockCopyDstAccessOrder,
            ABlockCopySrcVectorReadDim, // Src dim to be read in vector form
            3,                          // Dst dim to be written in vector form (KPack dimension)
            ABlockCopySrcDataPerRead,
            ABlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>({g_block_data_on_global, 0, m_block_data_on_global, 0},
                                        {0, 0, 0, 0});

        constexpr auto b_g_k_n1_b_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, in_N1, BPerBlock, KPack>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize,
            decltype(b_g_k_n1_b_kpack_global_desc),
            decltype(b_g_k_n1_b_kpack_block_desc),
            decltype(b_g_k_n1_b_kpack_block_desc.GetLengths()),
            BBlockCopyThreadSliceLengths_G_K_N1_B_KPack,
            BBlockCopyThreadClusterLengths_G_K_N1_B_KPack,
            BBlockCopyThreadClusterArrangeOrder,
            BBlockCopySrcAccessOrder,
            BBlockCopyDstAccessOrder,
            BBlockCopySrcVectorReadDim, // Src dim to be read in vector form
            4,                          // Dst dim to be written in vector form (KPack dimension)
            BBlockCopySrcDataPerRead,
            BBlockCopyDstDataPerWrite_KPACK,
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>(
            {g_block_data_on_global, 0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, BPerBlock * in_N1] is in LDS
        //     c_mtx[MPerBlock, BPerBlock * in_N1] is distributed among threads, and saved in
        //       register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<BPerBlock * in_N1>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            ABFloat,
            MPerWave,
            BPerWave,
            MWavePerBlock,
            BWavePerBlock,
            1,
            1>{};

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_g_k_n1_b_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block[a_block_space];
        __shared__ ABFloat p_b_block[b_block_space];

        // get zero-initialized output register of vector type
        auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();

        // preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block);
            b_blockwise_copy.Run(p_b_global, p_b_block);
        }

        constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
        constexpr auto blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0, 0>{};

        // main body
        for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
            k_block_data_begin += KPerBlock)
        {
            ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
            ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

            // load next data from device mem
            a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);

            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

            block_sync_lds();

            // GEMM on current data
            const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
                    p_a_block);
            const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
                    p_b_block);

            c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

            block_sync_lds();

            // store next data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block);
        }

        // tail
        {
            block_sync_lds();

            // GEMM on last data
            const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
                    p_a_block);
            const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
                    p_b_block);

            c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
        }

        // copy output: register to global memory
        {
            ///\todo inconsistent layout of xdlops and tensor
            // xdlops layout
            // M1 = num_groups;
            // M0 = group_size;
            // N1 = num_blks_per_wave;
            // N0 = num_threads_per_blks;
            constexpr auto CLayout = blockwise_gemm.GetOutputLayout();

            constexpr index_t M0 = CLayout.M1();
            constexpr index_t M1 = CLayout.N1();
            constexpr index_t M2 = CLayout.M0();

            constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
                c_g_m_n_global_desc,
                make_tuple(PassThrough<G>{},
                           UnMerge<Sequence<M / (M1 * M2), M1, M2>>{},
                           PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));

            // src descriptor
            constexpr auto c_g_m0_m1_m2_n_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});

            using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;

            constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
            constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();

            // force unrolling the output loop to get rid of scratches
            #pragma unroll
            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                // blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block =
                    blockwise_gemm.template GetBeginOfThreadMatrixC<MPerWave, B>(i);

                const index_t m_thread_data_on_global =
                    m_block_data_on_global + c_thread_mtx_on_block.row;

                const index_t n_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
                                                      decltype(c_g_m0_m1_m2_n_global_desc),
                                                      CThreadCopySliceLengths,
                                                      arithmetic_sequence_gen<0, 5, 1>::type,
                                                      4,
                                                      1,
                                                      1,
                                                      AddressSpace::Vgpr,
                                                      AddressSpace::Global,
                                                      CGlobalMemoryOp>(
                    {0, 0, 0, 0, 0},
                    {g_block_data_on_global,
                     m_thread_data_on_global / (M2 * M1),
                     m_thread_data_on_global % (M2 * M1) / M2,
                     m_thread_data_on_global % M2,
                     n_thread_data_on_global})
                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
            }
        }
    }
};

} // namespace ck
#endif
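
// [Illustrative sketch, not part of the original header.] The main loop of Run() above is a
// two-stage software pipeline over the K dimension: the current KPerBlock slice is consumed from
// LDS while the next slice is staged through per-thread registers, with barriers on either side
// of the compute step. A host-side model of that schedule, with the tiles reduced to plain
// vectors and a sum standing in for the GEMM (all names below are hypothetical):
#include <vector>

inline float xdlops_pipeline_model(const std::vector<float>& a, int K, int KPerBlock)
{
    std::vector<float> lds(a.begin(), a.begin() + KPerBlock); // preload first slice into "LDS"
    float acc = 0.f;
    for(int k0 = 0; k0 + KPerBlock < K; k0 += KPerBlock)
    {
        // stage the next slice in "registers" while the current one is still being consumed
        std::vector<float> regs(a.begin() + k0 + KPerBlock, a.begin() + k0 + 2 * KPerBlock);
        for(float v : lds)
            acc += v;    // "GEMM on current data" (stands in for blockwise_gemm.Run)
        lds.swap(regs);  // "store next data to LDS" after the second barrier
    }
    for(float v : lds)
        acc += v;        // tail: the last slice is already resident in "LDS"
    return acc;
}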
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
0 → 100644
#ifndef CK_XDLOPS_GEMM_HPP
#define CK_XDLOPS_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"

namespace ck {

enum struct mfma_instr
{
    // fp32
    mfma_f32_32x32x1xf32 = 0,
    mfma_f32_16x16x1xf32,
    mfma_f32_4x4x1xf32,
    mfma_f32_32x32x2xf32, // k reduction
    mfma_f32_16x16x4xf32, // k reduction
    // fp16
    mfma_f32_32x32x4f16,
    mfma_f32_16x16x4f16,
    mfma_f32_4x4x4f16,
    mfma_f32_32x32x8f16,  // k reduction
    mfma_f32_16x16x16f16, // k reduction
    // bfp16
    mfma_f32_32x32x2bf16,
    mfma_f32_16x16x2bf16,
    mfma_f32_4x4x2bf16,
    mfma_f32_32x32x4bf16, // k reduction
    mfma_f32_16x16x8bf16, // k reduction
};

template <mfma_instr instr>
struct mfma_info;

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 4;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 2;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 32;
    static constexpr index_t n               = 32;
    static constexpr index_t k               = 1;
    static constexpr index_t cycles          = 64;
    static constexpr index_t k_base          = 1;

    template <index_t MPerXdlops,
              index_t NPerXdlops,
              index_t AStride,
              index_t BStride,
              class FloatA,
              class FloatB,
              class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const float*>(a);
        const auto p_b = reinterpret_cast<const float*>(b);

        return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, AStride, BStride>::run(
            p_a, p_b, reg_c);
    }
};
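
// [Illustrative note, not part of the original header.] The fields above encode how one mfma
// instruction tiles a wavefront: num_threads_blk lanes own one n-column block,
// wave_size / num_threads_blk input blocks run side by side, and each lane holds num_regs_blk
// accumulators per output block. For mfma_f32_32x32x1xf32 the numbers work out as below; this is
// a pure compile-time restatement of the values defined above, added only as documentation:
namespace mfma_info_doc {
using info_32x32x1 = mfma_info<mfma_instr::mfma_f32_32x32x1xf32>;
static_assert(info_32x32x1::num_input_blks == 2, "64 lanes / 32 threads per block");
static_assert(info_32x32x1::num_regs_blk * info_32x32x1::wave_size ==
                  info_32x32x1::m * info_32x32x1::n,
              "16 regs x 64 lanes cover one 32x32 output block");
static_assert(info_32x32x1::num_regs_xdlops == 32,
              "2 output blocks -> 32 accumulator registers per lane");
} // namespace mfma_info_doc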
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x2xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
2
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
1
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_32x32x2f32
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
1
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_16x16x4f32
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
1
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
1
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_16x16x1f32
<
MPerXdlops
,
NPerXdlops
>
(
p_a
,
p_b
,
reg_c
);
}
};
// treat 4x4x1 as a single-blk 4x64 mfma
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
4
;
static
constexpr
index_t
m
=
4
;
static
constexpr
index_t
n
=
64
;
static
constexpr
index_t
k
=
1
;
static
constexpr
index_t
cycles
=
8
;
static
constexpr
index_t
k_base
=
1
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_4x4x1f32
<
MPerXdlops
,
NPerXdlops
>::
run
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x4f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
2
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
4
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_32x32x4f16
<
MPerXdlops
,
NPerXdlops
,
AStride
,
BStride
>::
run
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x8f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
8
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
4
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_32x32x8f16
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x16f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
16
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
4
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_16x16x16f16
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x4f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
4
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_16x16x4f16
<
MPerXdlops
,
NPerXdlops
>
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_4x4x4f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
4
;
static
constexpr
index_t
m
=
4
;
static
constexpr
index_t
n
=
64
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
8
;
static
constexpr
index_t
k_base
=
4
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_4x4x4f16
<
MPerXdlops
,
NPerXdlops
>::
run
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
2
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
2
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
2
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
ushort2_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
ushort2_t
*>
(
b
);
return
intrin_mfma_f32_32x32x2bf16
<
MPerXdlops
,
NPerXdlops
,
AStride
,
BStride
>::
run
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x4bf16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
2
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
ushort2_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
ushort2_t
*>
(
b
);
return
intrin_mfma_f32_32x32x4bf16
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x8bf16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
8
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
2
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
ushort2_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
ushort2_t
*>
(
b
);
return
intrin_mfma_f32_16x16x8bf16
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x2bf16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
2
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
2
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
ushort2_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
ushort2_t
*>
(
b
);
return
intrin_mfma_f32_16x16x2bf16
<
MPerXdlops
,
NPerXdlops
>
(
p_a
,
p_b
,
reg_c
);
}
};
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_4x4x2bf16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
4
;
static
constexpr
index_t
m
=
4
;
static
constexpr
index_t
n
=
64
;
static
constexpr
index_t
k
=
2
;
static
constexpr
index_t
cycles
=
8
;
static
constexpr
index_t
k_base
=
2
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
ushort2_t
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
ushort2_t
*>
(
b
);
return
intrin_mfma_f32_4x4x2bf16
<
MPerXdlops
,
NPerXdlops
>::
run
(
p_a
,
p_b
,
reg_c
);
}
};
template <mfma_instr instr,
          index_t MPerXdlops_,
          index_t NPerXdlops_,
          index_t MRepeats_,
          index_t NRepeats_,
          class OutputVecType_>
struct xdlops_info
{
    static constexpr auto mfma_type = mfma_info<instr>{};

    static constexpr index_t MPerXdlops = MPerXdlops_;
    static constexpr index_t NPerXdlops = NPerXdlops_;
    static constexpr index_t MRepeats   = MRepeats_;
    static constexpr index_t NRepeats   = NRepeats_;

    static constexpr bool IsABroadcast() { return NPerXdlops >= MPerXdlops; }

    static constexpr bool IsKReduction()
    {
        return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
    }

    static constexpr auto OutputVecType = OutputVecType_{};
};
template <class data_type,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB>
struct XdlopsGemm_t
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    __device__ static constexpr index_t GetNumBlksPerXdlops()
    {
        return (MPerXdlops * NPerXdlops) / (mfma_type.m * mfma_type.n);
    }

    __device__ constexpr XdlopsGemm_t()
    {
        static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 ||
                          NPerXdlops == 32 || NPerXdlops == 64,
                      "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");

        static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 ||
                          MPerXdlops == 32 || MPerXdlops == 64,
                      "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");

        static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");

        static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");

        static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
                      "m != num_input_blks * num_regs_blk");

        static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
                          mfma_type.num_output_blks == 1,
                      "incorrect num_output_blks");

        static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
                      "num_regs_blk incorrect");

        static_assert(mfma_type.k % mfma_type.k_base == 0, "k and k_base is inconsistent!");
    }

    __device__ static constexpr index_t GetRegSizePerXdlops()
    {
        return MPerXdlops * NPerXdlops / mfma_type.wave_size;
    }
#if CK_USE_AMD_XDLOPS_EMULATE
// emulate xdlops
template
<
index_t
M
,
index_t
N
,
index_t
K
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
XdlopsEmulate
(
const
FloatA
*
const
__restrict__
p_a_wave
,
const
FloatB
*
const
__restrict__
p_b_wave
,
FloatC
p_c_thread
)
const
{
const
index_t
laneId
=
get_thread_local_1d_id
()
%
mfma_type
.
wave_size
;
const
index_t
blk_id
=
laneId
/
mfma_type
.
num_threads_blk
;
const
index_t
blk_td
=
laneId
%
mfma_type
.
num_threads_blk
;
// K reduction
static_if
<
IsKReduction
>
{}([
&
](
auto
)
{
for
(
index_t
k
=
0
;
k
<
K
;
k
+=
mfma_type
.
num_input_blks
)
{
for
(
index_t
n
=
0
;
n
<
mfma_type
.
num_input_blks
;
++
n
)
{
index_t
a_off
=
(
k
+
n
)
*
M
;
index_t
b_off
=
(
k
+
n
)
*
N
;
index_t
c_off
=
0
;
for
(
index_t
m
=
0
;
m
<
mfma_type
.
num_regs_blk
;
++
m
)
{
index_t
aindex
=
m
%
mfma_type
.
group_size
+
blk_id
*
mfma_type
.
group_size
+
m
/
mfma_type
.
group_size
*
(
mfma_type
.
group_size
*
mfma_type
.
num_input_blks
);
index_t
bindex
=
blk_td
;
p_c_thread
.
n
[
m
+
c_off
]
+=
inner_product_with_conversion
<
float
>
{}(
p_a_wave
[
aindex
+
a_off
],
p_b_wave
[
bindex
+
b_off
]);
}
}
}
}).
Else
([
&
](
auto
)
{
static_if
<
IsABroadcast
>
{}([
&
](
auto
)
{
for
(
index_t
m_i
=
0
;
m_i
<
MRepeats
;
++
m_i
)
{
for
(
index_t
n_i
=
0
;
n_i
<
NRepeats
;
++
n_i
)
{
// ABroadcast
for
(
index_t
k
=
0
;
k
<
K
;
++
k
)
{
for
(
index_t
b
=
0
;
b
<
MPerXdlops
/
mfma_type
.
m
;
++
b
)
{
for
(
index_t
n
=
0
;
n
<
mfma_type
.
num_input_blks
;
++
n
)
{
index_t
a_off
=
k
*
M
+
b
*
mfma_type
.
m
+
MPerXdlops
*
m_i
;
index_t
b_off
=
k
*
N
+
n
*
mfma_type
.
num_threads_blk
+
NPerXdlops
*
n_i
;
index_t
c_off
=
n
*
mfma_type
.
num_regs_blk
+
b
*
mfma_type
.
num_regs_xdlops
+
(
NRepeats
*
m_i
+
n_i
)
*
GetRegSizePerXdlops
();
for
(
index_t
m
=
0
;
m
<
mfma_type
.
num_regs_blk
;
++
m
)
{
index_t
aindex
=
m
%
mfma_type
.
group_size
+
blk_id
*
mfma_type
.
group_size
+
m
/
mfma_type
.
group_size
*
(
mfma_type
.
group_size
*
mfma_type
.
num_input_blks
);
index_t
bindex
=
blk_td
;
p_c_thread
.
n
[
m
+
c_off
]
+=
inner_product_with_conversion
<
float
>
{}(
p_a_wave
[
aindex
+
a_off
],
p_b_wave
[
bindex
+
b_off
]);
}
}
}
}
}
}
}).
Else
([
&
](
auto
)
{
// BBroadcast
for
(
index_t
k
=
0
;
k
<
K
;
++
k
)
{
for
(
index_t
b
=
0
;
b
<
NPerXdlops
/
mfma_type
.
n
;
++
b
)
{
for
(
index_t
n
=
0
;
n
<
mfma_type
.
num_input_blks
;
++
n
)
{
index_t
a_off
=
k
*
M
+
n
*
mfma_type
.
m
;
index_t
b_off
=
k
*
N
+
b
*
mfma_type
.
n
;
index_t
c_off
=
n
*
mfma_type
.
num_regs_blk
+
b
*
mfma_type
.
num_regs_xdlops
;
for
(
index_t
m
=
0
;
m
<
mfma_type
.
num_regs_blk
;
++
m
)
{
index_t
aindex
=
m
%
mfma_type
.
group_size
+
blk_id
*
mfma_type
.
group_size
+
m
/
mfma_type
.
group_size
*
(
mfma_type
.
group_size
*
mfma_type
.
num_input_blks
);
index_t
bindex
=
blk_td
;
p_c_thread
.
n
[
m
+
c_off
]
+=
inner_product_with_conversion
<
float
>
{}(
p_a_wave
[
aindex
+
a_off
],
p_b_wave
[
bindex
+
b_off
]);
}
}
}
}
});
});
return
p_c_thread
;
}
#endif
    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
    __device__ FloatC Run(const FloatA* const __restrict__ p_a_wave,
                          const FloatB* const __restrict__ p_b_wave,
                          FloatC p_c_thread) const
    {
        static_assert(is_same<FloatA, FloatB>::value, "FloatA != FloatB");

        static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
                          is_same<data_type, ushort>::value,
                      "base data_type must be float, half, ushort!");

#if CK_USE_AMD_XDLOPS_EMULATE
        p_c_thread = XdlopsEmulate<M, N, K>(p_a_wave, p_b_wave, p_c_thread);
#else
        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;

        FloatA a[K * MRepeats];
        FloatB b[K * NRepeats];

        static_assert(sizeof(FloatA) % (sizeof(data_type) * mfma_type.k_base) == 0,
                      "wrong! FloatA is not consistent with mfma");

        static_assert(!IsKReduction || K % mfma_type.num_input_blks == 0,
                      "K cannot be divided by mfma_type.num_input_blks!");

        static_assert(!IsKReduction || (MRepeats == 1 && NRepeats == 1),
                      "KReduction does not support M/N Repeats!");

        constexpr index_t KRepeats = sizeof(FloatA) / (sizeof(data_type) * mfma_type.k_base);

        auto pa = reinterpret_cast<const data_type*>(&a);
        auto pb = reinterpret_cast<const data_type*>(&b);

        constexpr index_t AStride = K * KRepeats;
        constexpr index_t BStride = K * KRepeats;

        static_if<!IsKReduction>{}([&](auto) {

            for(index_t m_i = 0; m_i < MRepeats; ++m_i)
                for(index_t k_i = 0; k_i < K; ++k_i)
                    a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];

            for(index_t n_i = 0; n_i < NRepeats; ++n_i)
                for(index_t k_i = 0; k_i < K; ++k_i)
                    b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];

#if CK_WORKAROUND_SWDEV_229564
#pragma unroll
#endif
            for(index_t k_i = 0; k_i < K * KRepeats; ++k_i)
            {
                p_c_thread = mfma_type.template run<MPerXdlops * MRepeats,
                                                    NPerXdlops * NRepeats,
                                                    AStride,
                                                    BStride>(
                    &pa[k_i * mfma_type.k_base], &pb[k_i * mfma_type.k_base], p_c_thread);
            }

        }).Else([&](auto) {

            const index_t blk_id = laneId / mfma_type.num_threads_blk;
            const index_t blk_td = laneId % mfma_type.num_threads_blk;

            // load into registers
            for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
            {
                a[k_i] = p_a_wave[(k_i + blk_id) * M + blk_td];
                b[k_i] = p_b_wave[(k_i + blk_id) * N + blk_td];
            }

#if CK_WORKAROUND_SWDEV_229564
#pragma unroll
#endif
            for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
            {
                for(index_t i = 0; i < KRepeats; ++i)
                    p_c_thread =
                        mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>(
                            &pa[(k_i * KRepeats + i) * mfma_type.k_base],
                            &pb[(k_i * KRepeats + i) * mfma_type.k_base],
                            p_c_thread);
            }

        });
#endif
        return p_c_thread;
    }
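
    // [Illustrative note, not part of the original header.] KRepeats in Run() counts how many
    // k_base-wide mfma operands fit into one LDS vector element:
    // KRepeats = sizeof(FloatA) / (sizeof(data_type) * k_base). For example, with a hypothetical
    // 8-wide half vector as FloatA (KPack = 8) and an fp16 instruction with k_base = 4, every
    // vectorized LDS read feeds two mfma issues:
    static_assert((8 * sizeof(half_t)) / (sizeof(half_t) * 4) == 2,
                  "illustration: KPack = 8 with k_base = 4 gives KRepeats = 2");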
    __device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
    {
        const index_t xdlops_i = i / GetNumBlksPerXdlops();
        const index_t j        = i % GetNumBlksPerXdlops();

        const index_t m_i = xdlops_i / NRepeats;
        const index_t n_i = xdlops_i % NRepeats;

        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
        const index_t blk_id = laneId / mfma_type.num_threads_blk;
        const index_t blk_td = laneId % mfma_type.num_threads_blk;

        index_t col_blk = j % mfma_type.num_output_blks;
        index_t row_blk = j / mfma_type.num_output_blks;

        static_if<!IsABroadcast>{}([&](auto) {
            col_blk = j / mfma_type.num_output_blks;
            row_blk = j % mfma_type.num_output_blks;
        });

        index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
        index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;

        return MatrixIndex{row, col};
    }

    __device__ void SetZeroXdlopsRegs() const {}

    template <class FloatC>
    __device__ void ReadXdlopsRegs(FloatC* const __restrict__) const
    {
    }

    template <class data_type_  = data_type,
              index_t MPerWave_ = GemmMPerWave,
              index_t NPerWave_ = GemmNPerWave>
    static constexpr auto GetXdlopsInfo();
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
128
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
,
2
,
1
,
c_vec32_4_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
128
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
,
1
,
2
,
c_vec32_4_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
64
,
1
,
1
,
c_vec32_2_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
32
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
64
,
32
,
1
,
1
,
c_vec32_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
,
32
,
64
,
1
,
1
,
c_vec32_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
64
,
16
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
64
,
16
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
,
16
,
64
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
8
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
8
,
64
,
1
,
1
,
c_vec4_2_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
4
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
,
4
,
64
,
1
,
1
,
c_vec4_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
32
,
32
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2xf32
,
32
,
32
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
float
,
16
,
16
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
,
16
,
16
,
1
,
1
,
c_vec4_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
128
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x4f16
,
64
,
64
,
2
,
1
,
c_vec32_4_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
64
,
128
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x4f16
,
64
,
64
,
1
,
2
,
c_vec32_4_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
64
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x4f16
,
64
,
64
,
1
,
1
,
c_vec32_2_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
64
,
32
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x4f16
,
64
,
32
,
1
,
1
,
c_vec32_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
32
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x4f16
,
32
,
64
,
1
,
1
,
c_vec32_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
64
,
16
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4f16
,
64
,
16
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
16
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4f16
,
16
,
64
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
8
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x4f16
,
8
,
64
,
1
,
1
,
c_vec4_2_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
4
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x4f16
,
4
,
64
,
1
,
1
,
c_vec4_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
32
,
32
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x8f16
,
32
,
32
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
16
,
16
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x16f16
,
16
,
16
,
1
,
1
,
c_vec4_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
128
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
,
64
,
64
,
2
,
1
,
c_vec32_4_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
64
,
128
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
,
64
,
64
,
1
,
2
,
c_vec32_4_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
64
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
,
64
,
64
,
1
,
1
,
c_vec32_2_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
64
,
32
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
,
64
,
32
,
1
,
1
,
c_vec32_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
32
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x2bf16
,
32
,
64
,
1
,
1
,
c_vec32_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
64
,
16
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x2bf16
,
64
,
16
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
16
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x2bf16
,
16
,
64
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
8
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x2bf16
,
8
,
64
,
1
,
1
,
c_vec4_2_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
4
,
64
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x2bf16
,
4
,
64
,
1
,
1
,
c_vec4_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
32
,
32
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x4bf16
,
32
,
32
,
1
,
1
,
c_vec16_1_t
>
{};
}
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
ushort
,
16
,
16
>
()
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x8bf16
,
16
,
16
,
1
,
1
,
c_vec4_1_t
>
{};
}
    static constexpr index_t MRepeats   = GetXdlopsInfo().MRepeats;
    static constexpr index_t NRepeats   = GetXdlopsInfo().NRepeats;
    static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
    static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;

    static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
    static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();

    static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;

    struct OutputLayout
    {
        __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
        __device__ static constexpr index_t M0() { return mfma_type.group_size; }
        __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
        __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }

        __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }

        __device__ static constexpr index_t GetNumBlks()
        {
            return GetNumBlksPerXdlops() * MRepeats * NRepeats;
        }

        __device__ static constexpr auto CreateOutputVecZero()
        {
            return GetXdlopsInfo().OutputVecType.CreateVecZero();
        }
    };

    __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
};

} // namespace ck
#endif
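
// [Illustrative sketch, not part of the original headers.] GetXdlopsInfo() is the dispatch table
// that maps (data type, GemmMPerWave, GemmNPerWave) to a concrete mfma instruction plus the M/N
// repeat counts and the accumulator vector type. A compile-time usage sketch, checked against the
// fp32 64x64 specialization above (the namespace and alias below are hypothetical):
namespace xdlops_gemm_doc {
using gemm_f32_64x64 = ck::XdlopsGemm_t<float, 64, 64, 1, 1>;
static_assert(gemm_f32_64x64::MPerXdlops == 64 && gemm_f32_64x64::NPerXdlops == 64,
              "one 64x64 xdlops tile per wave");
static_assert(gemm_f32_64x64::MRepeats == 1 && gemm_f32_64x64::NRepeats == 1,
              "no M/N repeats needed at 64x64");
static_assert(!gemm_f32_64x64::IsKReduction,
              "mfma_f32_32x32x1xf32 keeps two output blocks, so no k reduction");
} // namespace xdlops_gemm_doc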
composable_kernel/include/utility/amd_xdlops.hpp
0 → 100644
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

#include "float_type.hpp"

namespace ck {

// A, B, C, cbsz, abid, blgp
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
    float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
    float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
    float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
    float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
    float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");

extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
    half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
    half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
    half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
    half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
    half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");

extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
    ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
    ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
    ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
    ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
    ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
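
// [Illustrative note, not part of the original header.] The three trailing integer operands of
// every llvm.amdgcn.mfma.* builtin are, in order, cbsz, abid and blgp (matching the comment at
// the top of this block). As documented for AMD's MFMA instructions, cbsz/abid select which input
// block of the A operand is broadcast across the wave and blgp permutes or broadcasts the B
// lanes; the wrappers below only ever use a few fixed combinations. A small restatement of the
// patterns that appear in this file (the struct and constants are hypothetical, for documentation
// only):
namespace amd_xdlops_doc {
struct mfma_modifiers
{
    int cbsz; // control broadcast size: how many blocks share one broadcast A block
    int abid; // which block supplies the broadcast A operand
    int blgp; // B lane-group pattern (0 = identity)
};
// e.g. the 64x64 fp32 path issues two calls, {1, 0, 0} then {1, 1, 0}, to fill both output
// blocks, while the 64x32 path uses {0, 0, 1} to broadcast B instead of A.
constexpr mfma_modifiers mfma_broadcast_a_block0{1, 0, 0};
constexpr mfma_modifiers mfma_broadcast_a_block1{1, 1, 0};
constexpr mfma_modifiers mfma_broadcast_b{0, 0, 1};
} // namespace amd_xdlops_doc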
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32;

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
128
,
AStride
,
BStride
>
{
__device__
static
c_vec32_4_t
::
VecType
run
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec32_4_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
1
,
1
,
0
);
reg_c
.
s
.
z
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
[
0
],
reg_b
[
BStride
],
reg_c
.
s
.
z
,
1
,
0
,
0
);
reg_c
.
s
.
w
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
[
0
],
reg_b
[
BStride
],
reg_c
.
s
.
w
,
1
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
,
AStride
,
BStride
>
{
__device__
static
c_vec32_2_t
::
VecType
run
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec32_2_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
1
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
32
,
AStride
,
BStride
>
{
__device__
static
c_vec32_1_t
::
VecType
run
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec32_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
1
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x1f32
<
32
,
64
,
AStride
,
BStride
>
{
__device__
static
c_vec32_1_t
::
VecType
run
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec32_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
return
reg_c
;
}
};
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_32x32x2f32
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
0
);
return
reg_c
;
}
__device__
c_vec4_1_t
::
VecType
intrin_mfma_f32_16x16x4f32
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
0
);
return
reg_c
;
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x1f32
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
);
template
<
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x1f32
<
16
,
64
>
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
2
,
0
,
0
);
return
reg_c
;
}
template
<
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x1f32
<
64
,
16
>
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
4
);
return
reg_c
;
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_4x4x1f32
;
template
<
>
struct
intrin_mfma_f32_4x4x1f32
<
4
,
64
>
{
__device__
static
c_vec4_1_t
::
VecType
run
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
4
,
0
,
0
);
return
reg_c
;
}
};
template
<
>
struct
intrin_mfma_f32_4x4x1f32
<
8
,
64
>
{
__device__
static
c_vec4_2_t
::
VecType
run
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec4_2_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
4
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
4
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
;
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
128
,
64
,
AStride
,
BStride
>
{
__device__
static
c_vec32_4_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_4_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
1
,
1
,
0
);
reg_c
.
s
.
z
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
AStride
],
reg_b
[
0
],
reg_c
.
s
.
z
,
1
,
0
,
0
);
reg_c
.
s
.
w
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
AStride
],
reg_b
[
0
],
reg_c
.
s
.
w
,
1
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
128
,
AStride
,
BStride
>
{
__device__
static
c_vec32_4_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_4_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
1
,
1
,
0
);
reg_c
.
s
.
z
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
BStride
],
reg_c
.
s
.
z
,
1
,
0
,
0
);
reg_c
.
s
.
w
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
BStride
],
reg_c
.
s
.
w
,
1
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
64
,
AStride
,
BStride
>
{
__device__
static
c_vec32_2_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_2_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
1
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
32
,
AStride
,
BStride
>
{
__device__
static
c_vec32_1_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
1
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
32
,
64
,
AStride
,
BStride
>
{
__device__
static
c_vec32_1_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
return
reg_c
;
}
};
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_32x32x8f16
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x8f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
0
);
return
reg_c
;
}
__device__
c_vec4_1_t
::
VecType
intrin_mfma_f32_16x16x16f16
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x16f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
0
);
return
reg_c
;
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x4f16
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
);
template
<
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x4f16
<
16
,
64
>
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
2
,
0
,
0
);
return
reg_c
;
}
template
<
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x4f16
<
64
,
16
>
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
4
);
return
reg_c
;
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_4x4x4f16
;
template
<
>
struct
intrin_mfma_f32_4x4x4f16
<
4
,
64
>
{
__device__
static
c_vec4_1_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
4
,
0
,
0
);
return
reg_c
;
}
};
template
<
>
struct
intrin_mfma_f32_4x4x4f16
<
8
,
64
>
{
__device__
static
c_vec4_2_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec4_2_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
4
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
4
,
1
,
0
);
return
reg_c
;
}
};
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16;

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
{
    __device__ static c_vec32_2_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
{
    __device__ static c_vec32_1_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
{
    __device__ static c_vec32_1_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        return reg_c;
    }
};
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
                                                            const ushort2_t* reg_b,
                                                            c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
    return reg_c;
}

__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
                                                           const ushort2_t* reg_b,
                                                           c_vec4_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
    return reg_c;
}

template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
                                                            const ushort2_t* reg_b,
                                                            c_vec16_1_t::VecType reg_c);

template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
                                                                    const ushort2_t* reg_b,
                                                                    c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
    return reg_c;
}

template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
                                                                    const ushort2_t* reg_b,
                                                                    c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
    return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x2bf16;

template <>
struct intrin_mfma_f32_4x4x2bf16<4, 64>
{
    __device__ static c_vec4_1_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
        return reg_c;
    }
};

template <>
struct intrin_mfma_f32_4x4x2bf16<8, 64>
{
    __device__ static c_vec4_2_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
        return reg_c;
    }
};

} // namespace ck
#endif
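The struct-style wrappers in this header (intrin_mfma_f32_4x4x4f16, intrin_mfma_f32_32x32x2bf16, intrin_mfma_f32_4x4x2bf16, and so on) let a caller select the MFMA variant for a given per-wave tile at compile time through partial specialization. A minimal sketch of such a call site follows; it is illustrative only, and accumulate_8x64_bf16, a_frag, b_frag and c_acc are hypothetical names, not identifiers from this commit (the ck:: qualification assumes, as elsewhere in the commit, that these vector types live in namespace ck).

#include "amd_xdlops.hpp"

// Hedged sketch: invoking the <8, 64> bf16 wrapper defined above.
__device__ void accumulate_8x64_bf16(const ck::ushort2_t* a_frag,
                                     const ck::ushort2_t* b_frag,
                                     ck::c_vec4_2_t::VecType& c_acc)
{
    // run() issues two mfma_f32_4x4x2bf16 operations, updating c_acc.s.x and
    // c_acc.s.y with different abid immediates, exactly as specialized above.
    c_acc = ck::intrin_mfma_f32_4x4x2bf16<8, 64>::run(a_frag, b_frag, c_acc);
}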
driver/include/gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp
0 → 100644
View file @
87a75734
#include "common_header.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "float_types.h"
template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
                                                                           const Tensor<T>& in_nchw,
                                                                           WeiDesc,
                                                                           const Tensor<T>& wei_kcyx,
                                                                           OutDesc,
                                                                           Tensor<T>& out_nkhw,
                                                                           ConvStrides,
                                                                           ConvDilations,
                                                                           InLeftPads,
                                                                           InRightPads,
                                                                           ck::index_t nrepeat)
{
    using namespace ck;

    // read params: problem description
    constexpr index_t G = CK_PARAM_PROBLEM_G;
    constexpr index_t N = CK_PARAM_PROBLEM_N;
    constexpr index_t K = CK_PARAM_PROBLEM_K;
    constexpr index_t C = CK_PARAM_PROBLEM_C;

    constexpr index_t Hi = CK_PARAM_PROBLEM_HI;
    constexpr index_t Wi = CK_PARAM_PROBLEM_WI;
    constexpr index_t Ho = CK_PARAM_PROBLEM_HO;
    constexpr index_t Wo = CK_PARAM_PROBLEM_WO;

    constexpr index_t Y = CK_PARAM_PROBLEM_Y;
    constexpr index_t X = CK_PARAM_PROBLEM_X;

    constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H;
    constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W;

    constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H;
    constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W;

    constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H;
    constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W;

    constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H;
    constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W;

    constexpr auto CPerGroup = C / G;

    constexpr auto in_n_c_hi_wi_desc =
        make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
    constexpr auto wei_k_cpergroup_y_x_desc =
        make_native_tensor_descriptor_packed(Sequence<K, CPerGroup, Y, X>{});
    constexpr auto out_n_k_ho_wo_desc =
        make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});

    using ConvStrides   = Sequence<ConvStrideH, ConvStrideW>;
    using ConvDilations = Sequence<ConvDilationH, ConvDilationW>;

    using InLeftPads  = Sequence<InLeftPadH, InLeftPadW>;
    using InRightPads = Sequence<InRightPadH, InRightPadW>;
    // read params: tuning parameters
    constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK;
    constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK;
    constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK;
    constexpr index_t GemmMPerWave  = CK_PARAM_TUNABLE_GEMM_M_PER_WAVE;
    constexpr index_t GemmNPerWave  = CK_PARAM_TUNABLE_GEMM_N_PER_WAVE;
    constexpr index_t GemmKPack     = CK_PARAM_TUNABLE_GEMM_KPACK;

    // read params: dependent parameters
    constexpr index_t BlockSize = CK_PARAM_DEPENDENT_BLOCK_SIZE;
    constexpr index_t GridSize  = CK_PARAM_DEPENDENT_GRID_SIZE;

    // A matrix copy
    constexpr index_t GemmABlockCopyClusterLengths_GemmK =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
    constexpr index_t GemmABlockCopyClusterLengths_GemmM =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
    constexpr index_t GemmABlockCopyClusterLengths_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;

    constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
        GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
    constexpr index_t GemmABlockCopyThreadSliceLengths_GemmM =
        GemmMPerBlock / GemmABlockCopyClusterLengths_GemmM;
    constexpr index_t GemmABlockCopyThreadSliceLengths_GemmKPack =
        GemmKPack / GemmABlockCopyClusterLengths_GemmKPack;
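    // Illustrative sanity check (added for exposition, not part of the original
    // file): the per-thread slice lengths are the block tile divided by the copy
    // cluster lengths, so the cluster lengths are assumed to divide the tile
    // sizes exactly.
    static_assert(GemmABlockCopyThreadSliceLengths_GemmK * GemmABlockCopyClusterLengths_GemmK ==
                      GemmKPerBlock,
                  "A-copy cluster lengths must evenly divide GemmKPerBlock");
    static_assert(GemmABlockCopyThreadSliceLengths_GemmM * GemmABlockCopyClusterLengths_GemmM ==
                      GemmMPerBlock,
                  "A-copy cluster lengths must evenly divide GemmMPerBlock");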
    using GemmABlockCopyClusterLengths_GemmG_GemmK_GemmM_GemmKPack =
        Sequence<1,
                 GemmABlockCopyClusterLengths_GemmK,
                 GemmABlockCopyClusterLengths_GemmM,
                 GemmABlockCopyClusterLengths_GemmKPack>;

    using GemmABlockCopySubLengths_GemmG_GemmK_GemmM_GemmKPack =
        Sequence<1,
                 GemmABlockCopyThreadSliceLengths_GemmK,
                 GemmABlockCopyThreadSliceLengths_GemmM,
                 GemmABlockCopyThreadSliceLengths_GemmKPack>;

    using GemmABlockCopyThreadClusterArrangeOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
    using GemmABlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
    using GemmABlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_KPACK;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;

    // B matrix copy
    constexpr index_t GemmBBlockCopyClusterLengths_GemmK =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
    constexpr index_t GemmBBlockCopyClusterLengths_GemmN =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
    constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;

    constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
        GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
    constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmN =
        GemmNPerBlock / GemmBBlockCopyClusterLengths_GemmN;
    constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmKPack =
        GemmKPack / GemmBBlockCopyClusterLengths_GemmKPack;

    using GemmBBlockCopyClusterLengths_GemmG_GemmK_GemmN_GemmKPack =
        Sequence<1,
                 GemmBBlockCopyClusterLengths_GemmK,
                 GemmBBlockCopyClusterLengths_GemmN,
                 GemmBBlockCopyClusterLengths_GemmKPack>;

    using GemmBBlockCopySubLengths_GemmG_GemmK_GemmN_GemmKPack =
        Sequence<1,
                 GemmBBlockCopyThreadSliceLengths_GemmK,
                 GemmBBlockCopyThreadSliceLengths_GemmN,
                 GemmBBlockCopyThreadSliceLengths_GemmKPack>;

    using GemmBBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
    using GemmBBlockCopySrcAccessOrder            = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
    using GemmBBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;
    // gridwise GEMM
    constexpr auto wkgrp_schd_order = NBlock1MBlock0;

    constexpr auto gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
        GridSize,
        BlockSize,
        FLOAT,       // Input data type
        FLOAT_ACCUM, // Acc data type
        FLOAT,       // Output data type
        decltype(in_n_c_hi_wi_desc),
        decltype(wei_k_cpergroup_y_x_desc),
        decltype(out_n_k_ho_wo_desc),
        G,
        ConvStrides,
        ConvDilations,
        InLeftPads,
        InRightPads,
        GemmMPerBlock,
        GemmNPerBlock,
        GemmKPerBlock,
        GemmMPerWave,
        GemmNPerWave,
        GemmKPack,
        GemmABlockCopySubLengths_GemmG_GemmK_GemmM_GemmKPack,
        GemmABlockCopyClusterLengths_GemmG_GemmK_GemmM_GemmKPack,
        GemmABlockCopyThreadClusterArrangeOrder,
        GemmABlockCopySrcAccessOrder,
        GemmABlockCopyDstAccessOrder,
        GemmABlockCopySrcDataPerRead_GemmKPack,
        GemmABlockCopyDstDataPerWrite_GemmKPack,
        GemmBBlockCopySubLengths_GemmG_GemmK_GemmN_GemmKPack,
        GemmBBlockCopyClusterLengths_GemmG_GemmK_GemmN_GemmKPack,
        GemmBBlockCopyThreadClusterArrangeOrder,
        GemmBBlockCopySrcAccessOrder,
        GemmBBlockCopyDstAccessOrder,
        GemmBBlockCopySrcDataPerRead_GemmN,
        GemmBBlockCopyDstDataPerWrite_GemmKPack,
        wkgrp_schd_order>{};

    gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);
}
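In the driver above, GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw::Run is a __device__ member, and the pointers p_in_global, p_wei_global and p_out_global are not declared in this excerpt; they are expected to be device buffers prepared by surrounding driver code that is not shown here. A minimal sketch of the kind of launch wrapper this implies follows; run_gridwise_convolution is a hypothetical name, not part of this commit.

// Hedged sketch only: a trivial kernel that default-constructs the gridwise
// functor and forwards the global-memory pointers to its Run() method.
template <class GridwiseConv, class ABFloat, class CFloat>
__global__ void run_gridwise_convolution(const ABFloat* p_in_global,
                                         const ABFloat* p_wei_global,
                                         CFloat* p_out_global)
{
    GridwiseConv{}.Run(p_in_global, p_wei_global, p_out_global);
}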