gaoqiong / composable_kernel · Commit 7084b152
Authored May 13, 2021 by Jing Zhang
Parent: be49a8c5

    working on blockwise_gemm_xdlops

Showing 4 changed files, with 1872 additions and 18 deletions:

    composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp          +180    -0
    composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp    +17   -18
    composable_kernel/include/tensor_operation/xdlops_gemm.hpp                   +1060    -0
    composable_kernel/include/utility/amd_xdlops.hpp                              +615    -0
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp (new file, mode 100644)
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          class ABlockDesc,
          class BBlockDesc,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmKPerWave,
          index_t GemmMWaves,
          index_t GemmNWaves,
          index_t GemmDataPerReadA, // \todo unused parameter, remove
          index_t GemmDataPerReadB  // \todo unused parameter, remove
          >
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr auto XdlopsGemm =
        XdlopsGemm_t<float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};

    index_t mMyWaveOffsetA;
    index_t mMyWaveOffsetB;

    static constexpr index_t WaveSize = 64;

    __device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }

    __device__ constexpr auto GetNumBlks() const
    {
        return XdlopsGemm.GetOutputLayout().GetNumBlks();
    }

    __device__ constexpr auto GetBlkSize() const
    {
        return XdlopsGemm.GetOutputLayout().GetBlkSize();
    }

    __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
        constexpr index_t N = BBlockDesc{}.GetLength(I1);

        static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
        static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");

        static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
                      "BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");

        const index_t waveId   = get_thread_local_1d_id() / WaveSize;
        const index_t waveId_m = waveId / GemmNWaves;
        const index_t waveId_n = waveId % GemmNWaves;

        mMyWaveOffsetA = waveId_m * GemmMPerWave;
        mMyWaveOffsetB = waveId_n * GemmNPerWave;
    }

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf =
            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf =
            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

#if 0
        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                FloatB,
                                                FloatC,
                                                decltype(a_thread_desc_),
                                                decltype(b_thread_desc_),
                                                CThreadDesc,
                                                Sequence<GemmKPerWave>,
                                                Sequence<M0_, M1PerThread>,
                                                Sequence<N0_, N1PerThread>>{};

        constexpr index_t K = ABlockDesc{}.GetLength(I0);

        static_for<0, K, GemmKPerWave>{}([&](auto k) {
            a_thread_copy_.Run(ABlockDesc{},
                               make_tuple(k, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

            b_thread_copy_.Run(BBlockDesc{},
                               make_tuple(k, I0, I0),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(I0, I0, I0),
                               b_thread_buf);

            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(I0, I0, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I0, I0));
        });
#endif
    }

    template <index_t AStride = GemmMPerWave, index_t BStride = GemmNPerWave>
    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
    {
        const index_t waveId = get_thread_local_1d_id() / WaveSize;

        const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(i);

        const index_t col = (waveId % GemmNWaves) * BStride + thread_mtx_on_blk.col;
        const index_t row = (waveId / GemmNWaves) * AStride + thread_mtx_on_blk.row;

        return MatrixIndex{row, col};
    }

    private:
    // A[K, M]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<GemmKPerWave>{}, Number<1>{}));

    // B[K, N]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<GemmKPerWave>{}, Number<1>{}));

    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                                FloatA,
                                                                ABlockDesc,
                                                                decltype(a_thread_desc_),
                                                                Sequence<GemmKPerWave, 1>,
                                                                Sequence<0, 1>,
                                                                1,
                                                                1,
                                                                1>;

    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                                FloatB,
                                                                BBlockDesc,
                                                                decltype(b_thread_desc_),
                                                                Sequence<GemmKPerWave, 1>,
                                                                Sequence<0, 1>,
                                                                1,
                                                                1,
                                                                1>;

    // AThreadCopy a_thread_copy_;
    // BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif
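The constructor above fixes the wave tiling at compile time: a block of BlockSize threads is split into GemmMWaves x GemmNWaves waves of 64 lanes, and wave w owns the tile at row w / GemmNWaves, column w % GemmNWaves. A minimal host-side sketch of that mapping, assuming the 2x2-wave, 64x64-per-wave configuration the gridwise kernel below hard-codes (values are illustrative, not part of the commit):

#include <cstdio>

int main()
{
    const int GemmMWaves = 2, GemmNWaves = 2;        // assumed wave grid
    const int GemmMPerWave = 64, GemmNPerWave = 64;  // assumed per-wave tile

    for(int waveId = 0; waveId < GemmMWaves * GemmNWaves; ++waveId)
    {
        const int waveId_m = waveId / GemmNWaves; // row of the wave grid
        const int waveId_n = waveId % GemmNWaves; // column of the wave grid

        // same arithmetic as mMyWaveOffsetA / mMyWaveOffsetB above
        std::printf("wave %d -> A offset %d, B offset %d\n",
                    waveId, waveId_m * GemmMPerWave, waveId_n * GemmNPerWave);
    }
    return 0;
}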
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
@@ -5,7 +5,7 @@
 #include "dynamic_multi_index_transform_helper.hpp"
 #include "dynamic_tensor_descriptor.hpp"
 #include "dynamic_tensor_descriptor_helper.hpp"
-#include "blockwise_gemm_v2.hpp"
+#include "blockwise_gemm_xdlops.hpp"
 #include "blockwise_dynamic_tensor_slice_transfer.hpp"
 #include "threadwise_dynamic_tensor_slice_transfer.hpp"
 #include "threadwise_dynamic_tensor_slice_set.hpp"

@@ -313,23 +313,20 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                 Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));

-        const auto blockwise_gemm =
-            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
-                                                               FloatAB,
-                                                               FloatAB,
-                                                               FloatAcc,
-                                                               decltype(a_k_m0_m1_block_desc),
-                                                               decltype(b_k_n0_n1_block_desc),
-                                                               decltype(c_m0_m1_n0_n1_thread_desc),
-                                                               MPerThread,
-                                                               NPerThread,
-                                                               KPerThread,
-                                                               MLevel0Cluster,
-                                                               NLevel0Cluster,
-                                                               MLevel1Cluster,
-                                                               NLevel1Cluster,
-                                                               MPerThread,
-                                                               NPerThread>{};
+        const auto blockwise_gemm =
+            BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
+                                                 FloatAB,
+                                                 FloatAB,
+                                                 FloatAcc,
+                                                 decltype(a_k_m_block_desc),
+                                                 decltype(b_k_n_block_desc),
+                                                 64,        // MPerWave
+                                                 64,        // NPerWave
+                                                 KPerBlock,
+                                                 2,         // MWaves
+                                                 2,         // NWaves
+                                                 1,         // GemmDataPerReadM
+                                                 1          // GemmDataPerReadN
+                                                 >{};

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size = ...

@@ -477,6 +474,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
         }

+#if 0
         // output: register to global memory
         {
             constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
 ...

@@ -512,6 +510,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                               c_global_buf,
                               c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         }
+#endif
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
 ...
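The replacement instantiation above hard-codes a 64x64 tile per wave on a 2x2 wave grid, and the blockwise GEMM's constructor asserts then pin the launch shape. A quick check of that arithmetic (illustrative static_asserts, not part of the commit):

static_assert(2 * 2 * 64 == 256, "BlockSize == GemmMWaves * GemmNWaves * WaveSize");
static_assert(2 * 64 == 128, "block tile M = GemmMWaves * GemmMPerWave = 128 (likewise N)");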
composable_kernel/include/tensor_operation/xdlops_gemm.hpp (new file, mode 100644)
#ifndef CK_XDLOPS_GEMM_HPP
#define CK_XDLOPS_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
#include "amd_xdlops.hpp"

namespace ck {

enum struct mfma_instr
{
    // fp32
    mfma_f32_32x32x1xf32 = 0,
    mfma_f32_16x16x1xf32,
    mfma_f32_4x4x1xf32,
    mfma_f32_32x32x2xf32, // k reduction
    mfma_f32_16x16x4xf32, // k reduction
    // fp16
    mfma_f32_32x32x4f16,
    mfma_f32_16x16x4f16,
    mfma_f32_4x4x4f16,
    mfma_f32_32x32x8f16,  // k reduction
    mfma_f32_16x16x16f16, // k reduction
    // bfp16
    mfma_f32_32x32x2bf16,
    mfma_f32_16x16x2bf16,
    mfma_f32_4x4x2bf16,
    mfma_f32_32x32x4bf16, // k reduction
    mfma_f32_16x16x8bf16, // k reduction
};

template <mfma_instr instr>
struct mfma_info;

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 4;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 2;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 32;
    static constexpr index_t n               = 32;
    static constexpr index_t k               = 1;
    static constexpr index_t cycles          = 64;
    static constexpr index_t k_base          = 1;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const float*>(a);
        const auto p_b = reinterpret_cast<const float*>(b);

        return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, AStride, BStride>::run(
            p_a, p_b, reg_c);
    }
};
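// Illustration (not part of the commit): spot-check of the derived trait values
// above, mirroring the static_asserts in XdlopsGemm_t further down. For
// mfma_f32_32x32x1xf32, each lane holds num_regs_blk = 4 * 4 = 16 accumulators
// per 32x32 output block, and a 64-lane wave spans num_input_blks = 64 / 32 = 2
// input blocks.
static_assert(mfma_info<mfma_instr::mfma_f32_32x32x1xf32>::num_regs_blk == 16, "");
static_assert(mfma_info<mfma_instr::mfma_f32_32x32x1xf32>::num_input_blks == 2, "");
static_assert(mfma_info<mfma_instr::mfma_f32_32x32x1xf32>::num_regs_blk * 64 == 32 * 32,
              "one output block spreads m * n values over wave_size lanes");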
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 4;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 32;
    static constexpr index_t n               = 32;
    static constexpr index_t k               = 2;
    static constexpr index_t cycles          = 64;
    static constexpr index_t k_base          = 1;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const float*>(a);
        const auto p_b = reinterpret_cast<const float*>(b);

        return intrin_mfma_f32_32x32x2f32(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 16;
    static constexpr index_t n               = 16;
    static constexpr index_t k               = 4;
    static constexpr index_t cycles          = 32;
    static constexpr index_t k_base          = 1;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const float*>(a);
        const auto p_b = reinterpret_cast<const float*>(b);

        return intrin_mfma_f32_16x16x4f32(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 4;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 16;
    static constexpr index_t n               = 16;
    static constexpr index_t k               = 1;
    static constexpr index_t cycles          = 32;
    static constexpr index_t k_base          = 1;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const float*>(a);
        const auto p_b = reinterpret_cast<const float*>(b);

        return intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
    }
};

// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 64;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = 1;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = 4;
    static constexpr index_t m               = 4;
    static constexpr index_t n               = 64;
    static constexpr index_t k               = 1;
    static constexpr index_t cycles          = 8;
    static constexpr index_t k_base          = 1;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const float*>(a);
        const auto p_b = reinterpret_cast<const float*>(b);

        return intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 4;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 2;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 32;
    static constexpr index_t n               = 32;
    static constexpr index_t k               = 4;
    static constexpr index_t cycles          = 64;
    static constexpr index_t k_base          = 4;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const half4_t*>(a);
        const auto p_b = reinterpret_cast<const half4_t*>(b);

        return intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
            p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 4;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 32;
    static constexpr index_t n               = 32;
    static constexpr index_t k               = 8;
    static constexpr index_t cycles          = 64;
    static constexpr index_t k_base          = 4;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const half4_t*>(a);
        const auto p_b = reinterpret_cast<const half4_t*>(b);

        return intrin_mfma_f32_32x32x8f16(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 16;
    static constexpr index_t n               = 16;
    static constexpr index_t k               = 16;
    static constexpr index_t cycles          = 32;
    static constexpr index_t k_base          = 4;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const half4_t*>(a);
        const auto p_b = reinterpret_cast<const half4_t*>(b);

        return intrin_mfma_f32_16x16x16f16(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 4;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 16;
    static constexpr index_t n               = 16;
    static constexpr index_t k               = 4;
    static constexpr index_t cycles          = 32;
    static constexpr index_t k_base          = 4;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const half4_t*>(a);
        const auto p_b = reinterpret_cast<const half4_t*>(b);

        return intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 64;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = 1;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = 4;
    static constexpr index_t m               = 4;
    static constexpr index_t n               = 64;
    static constexpr index_t k               = 4;
    static constexpr index_t cycles          = 8;
    static constexpr index_t k_base          = 4;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const half4_t*>(a);
        const auto p_b = reinterpret_cast<const half4_t*>(b);

        return intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 4;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 2;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 32;
    static constexpr index_t n               = 32;
    static constexpr index_t k               = 2;
    static constexpr index_t cycles          = 64;
    static constexpr index_t k_base          = 2;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
        const auto p_b = reinterpret_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
            p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 4;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 32;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 32;
    static constexpr index_t n               = 32;
    static constexpr index_t k               = 4;
    static constexpr index_t cycles          = 64;
    static constexpr index_t k_base          = 2;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
        const auto p_b = reinterpret_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 16;
    static constexpr index_t n               = 16;
    static constexpr index_t k               = 8;
    static constexpr index_t cycles          = 32;
    static constexpr index_t k_base          = 2;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
        const auto p_b = reinterpret_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 16;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = wave_size / num_threads_blk;
    static constexpr index_t num_output_blks = 4;
    static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
    static constexpr index_t m               = 16;
    static constexpr index_t n               = 16;
    static constexpr index_t k               = 2;
    static constexpr index_t cycles          = 32;
    static constexpr index_t k_base          = 2;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
        const auto p_b = reinterpret_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
    }
};

template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
    static constexpr index_t group_size      = 4;
    static constexpr index_t num_groups_blk  = 1;
    static constexpr index_t num_regs_blk    = group_size * num_groups_blk;
    static constexpr index_t num_threads_blk = 64;
    static constexpr index_t wave_size       = 64;
    static constexpr index_t num_input_blks  = 1;
    static constexpr index_t num_output_blks = 1;
    static constexpr index_t num_regs_xdlops = 4;
    static constexpr index_t m               = 4;
    static constexpr index_t n               = 64;
    static constexpr index_t k               = 2;
    static constexpr index_t cycles          = 8;
    static constexpr index_t k_base          = 2;

    template <index_t MPerXdlops, index_t NPerXdlops, index_t AStride, index_t BStride,
              class FloatA, class FloatB, class FloatC>
    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
    {
        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
        const auto p_b = reinterpret_cast<const ushort2_t*>(b);

        return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
    }
};
template <mfma_instr instr,
          index_t MPerXdlops_,
          index_t NPerXdlops_,
          index_t MRepeats_,
          index_t NRepeats_,
          class OutputVecType_>
struct xdlops_info
{
    static constexpr auto mfma_type = mfma_info<instr>{};

    static constexpr index_t MPerXdlops = MPerXdlops_;
    static constexpr index_t NPerXdlops = NPerXdlops_;
    static constexpr index_t MRepeats   = MRepeats_;
    static constexpr index_t NRepeats   = NRepeats_;

    static constexpr bool IsABroadcast() { return NPerXdlops >= MPerXdlops; }

    static constexpr bool IsKReduction()
    {
        return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
    }

    static constexpr auto OutputVecType = OutputVecType_{};
};
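// Illustration (not part of the commit): how the two predicates classify the
// instruction set above. The "k reduction" variants are exactly those with one
// output block but several input blocks, so partial sums over k land in the
// same accumulators.
using kred_example  = xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, c_vec16_1_t>;
using bcast_example = xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, c_vec32_2_t>;
static_assert(kred_example::IsKReduction(), "num_output_blks == 1, num_input_blks == 2");
static_assert(!bcast_example::IsKReduction() && bcast_example::IsABroadcast(),
              "num_output_blks == 2; NPerXdlops >= MPerXdlops selects the A-broadcast path");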
template <class data_type,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB>
struct XdlopsGemm_t
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    __device__ static constexpr index_t GetNumBlksPerXdlops()
    {
        return (MPerXdlops * NPerXdlops) / (mfma_type.m * mfma_type.n);
    }

    __device__ constexpr XdlopsGemm_t()
    {
        static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 ||
                          NPerXdlops == 32 || NPerXdlops == 64,
                      "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");

        static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 ||
                          MPerXdlops == 32 || MPerXdlops == 64,
                      "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");

        static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");

        static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");

        static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
                      "m != num_input_blks * num_regs_blk");

        static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
                          mfma_type.num_output_blks == 1,
                      "incorrect num_output_blks");

        static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
                      "num_regs_blk incorrect");

        static_assert(mfma_type.k % mfma_type.k_base == 0, "k and k_base is inconsistent!");
    }

    __device__ static constexpr index_t GetRegSizePerXdlops()
    {
        return MPerXdlops * NPerXdlops / mfma_type.wave_size;
    }
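    // Illustration (not part of the commit): GetRegSizePerXdlops() is the
    // per-lane accumulator footprint, which is what the c_vec*_t types in
    // amd_xdlops.hpp are sized to. E.g. a 64x64 xdlops tile spreads 4096
    // outputs over 64 lanes:
    static_assert(64 * 64 / 64 == 64, "64x64 tile -> 64 floats per lane (c_vec32_2_t::n[64])");
    static_assert(32 * 64 / 64 == 32, "32x64 tile -> 32 floats per lane (c_vec32_1_t::n[32])");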
#if CK_USE_AMD_XDLOPS_EMULATE
    // emulate xdlops
    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
    __device__ FloatC XdlopsEmulate(const FloatA* const __restrict__ p_a_wave,
                                    const FloatB* const __restrict__ p_b_wave,
                                    FloatC p_c_thread) const
    {
        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
        const index_t blk_id = laneId / mfma_type.num_threads_blk;
        const index_t blk_td = laneId % mfma_type.num_threads_blk;

        // K reduction
        static_if<IsKReduction>{}([&](auto) {
            for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
            {
                for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
                {
                    index_t a_off = (k + n) * M;
                    index_t b_off = (k + n) * N;
                    index_t c_off = 0;

                    for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
                    {
                        index_t aindex = m % mfma_type.group_size +
                                         blk_id * mfma_type.group_size +
                                         m / mfma_type.group_size *
                                             (mfma_type.group_size * mfma_type.num_input_blks);
                        index_t bindex = blk_td;

                        p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
                            p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
                    }
                }
            }
        }).Else([&](auto) {
            static_if<IsABroadcast>{}([&](auto) {
                for(index_t m_i = 0; m_i < MRepeats; ++m_i)
                {
                    for(index_t n_i = 0; n_i < NRepeats; ++n_i)
                    {
                        // ABroadcast
                        for(index_t k = 0; k < K; ++k)
                        {
                            for(index_t b = 0; b < MPerXdlops / mfma_type.m; ++b)
                            {
                                for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
                                {
                                    index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i;
                                    index_t b_off =
                                        k * N + n * mfma_type.num_threads_blk + NPerXdlops * n_i;
                                    index_t c_off =
                                        n * mfma_type.num_regs_blk +
                                        b * mfma_type.num_regs_xdlops +
                                        (NRepeats * m_i + n_i) * GetRegSizePerXdlops();

                                    for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
                                    {
                                        index_t aindex =
                                            m % mfma_type.group_size +
                                            blk_id * mfma_type.group_size +
                                            m / mfma_type.group_size *
                                                (mfma_type.group_size * mfma_type.num_input_blks);
                                        index_t bindex = blk_td;

                                        p_c_thread.n[m + c_off] +=
                                            inner_product_with_conversion<float>{}(
                                                p_a_wave[aindex + a_off],
                                                p_b_wave[bindex + b_off]);
                                    }
                                }
                            }
                        }
                    }
                }
            }).Else([&](auto) {
                // BBroadcast
                for(index_t k = 0; k < K; ++k)
                {
                    for(index_t b = 0; b < NPerXdlops / mfma_type.n; ++b)
                    {
                        for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
                        {
                            index_t a_off = k * M + n * mfma_type.m;
                            index_t b_off = k * N + b * mfma_type.n;
                            index_t c_off =
                                n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;

                            for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
                            {
                                index_t aindex =
                                    m % mfma_type.group_size + blk_id * mfma_type.group_size +
                                    m / mfma_type.group_size *
                                        (mfma_type.group_size * mfma_type.num_input_blks);
                                index_t bindex = blk_td;

                                p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
                                    p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
                            }
                        }
                    }
                }
            });
        });

        return p_c_thread;
    }
#endif
    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
    __device__ FloatC Run(const FloatA* const __restrict__ p_a_wave,
                          const FloatB* const __restrict__ p_b_wave,
                          FloatC p_c_thread) const
    {
        static_assert(is_same<FloatA, FloatB>::value, "FloatA != FloatB");

        static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
                          is_same<data_type, ushort>::value,
                      "base data_type must be float, half, ushort!");

#if CK_USE_AMD_XDLOPS_EMULATE
        p_c_thread = XdlopsEmulate<M, N, K>(p_a_wave, p_b_wave, p_c_thread);
#else
        constexpr index_t KPACT = sizeof(FloatA) / sizeof(data_type);

        static_assert(KPACT % mfma_type.k_base == 0, "wrong! KPACT is not supported by mfma");

        constexpr index_t KRepeats = KPACT / mfma_type.k_base;

        static_assert(!IsKReduction || K % mfma_type.num_input_blks == 0,
                      "K cannot divided by mfma_type.num_input_blks!");

        constexpr index_t KPerThread = IsKReduction ? K / mfma_type.num_input_blks : K;

        static_assert(!IsKReduction || (MRepeats == 1 && NRepeats == 1),
                      "KReduction does not support M/N Repeats!");

        FloatA a[KPerThread * MRepeats];
        FloatB b[KPerThread * NRepeats];

        auto pa = reinterpret_cast<const data_type*>(&a);
        auto pb = reinterpret_cast<const data_type*>(&b);

        constexpr index_t AStride = KPerThread * KRepeats;
        constexpr index_t BStride = KPerThread * KRepeats;

        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;

        static_if<!IsKReduction>{}([&](auto) {
            for(index_t m_i = 0; m_i < MRepeats; ++m_i)
                for(index_t k_i = 0; k_i < KPerThread; ++k_i)
                    a[k_i + m_i * KPerThread] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];

            for(index_t n_i = 0; n_i < NRepeats; ++n_i)
                for(index_t k_i = 0; k_i < KPerThread; ++k_i)
                    b[k_i + n_i * KPerThread] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];
        }).Else([&](auto) {
            const index_t blk_id = laneId / mfma_type.num_threads_blk;
            const index_t blk_td = laneId % mfma_type.num_threads_blk;

            for(index_t k_i = 0; k_i < KPerThread; ++k_i)
            {
                a[k_i] = p_a_wave[(k_i * mfma_type.num_input_blks + blk_id) * M + blk_td];
                b[k_i] = p_b_wave[(k_i * mfma_type.num_input_blks + blk_id) * N + blk_td];
            }
        });

#if CK_WORKAROUND_SWDEV_229564
#pragma unroll
#endif
        for(index_t k_i = 0; k_i < KPerThread * KRepeats; ++k_i)
        {
            p_c_thread = mfma_type.template run<MPerXdlops * MRepeats,
                                                NPerXdlops * NRepeats,
                                                AStride,
                                                BStride>(
                &pa[k_i * mfma_type.k_base], &pb[k_i * mfma_type.k_base], p_c_thread);
        }
#endif

        return p_c_thread;
    }
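    // Worked example (assumed types, not part of the commit): KPACT counts base
    // elements packed into one FloatA, and KRepeats is how many k_base-sized
    // MFMA issues that packing expands into. With data_type = half_t and
    // FloatA = half4_t, KPACT = sizeof(half4_t) / sizeof(half_t) = 4, so
    // mfma_f32_32x32x4f16 (k_base = 4) gives KRepeats = 1: one intrinsic call
    // per packed element. FloatA = half_t would instead trip the
    // "KPACT is not supported by mfma" static_assert above (1 % 4 != 0).
    static_assert(4 % 4 == 0 && 4 / 4 == 1, "half4_t packing -> KRepeats = 1 for 32x32x4f16");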
    __device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
    {
        const index_t xdlops_i = i / GetNumBlksPerXdlops();
        const index_t j        = i % GetNumBlksPerXdlops();

        const index_t m_i = xdlops_i / NRepeats;
        const index_t n_i = xdlops_i % NRepeats;

        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
        const index_t blk_id = laneId / mfma_type.num_threads_blk;
        const index_t blk_td = laneId % mfma_type.num_threads_blk;

        index_t col_blk = j % mfma_type.num_output_blks;
        index_t row_blk = j / mfma_type.num_output_blks;

        static_if<!IsABroadcast>{}([&](auto) {
            col_blk = j / mfma_type.num_output_blks;
            row_blk = j % mfma_type.num_output_blks;
        });

        index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
        index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;

        return MatrixIndex{row, col};
    }

    __device__ void SetZeroXdlopsRegs() const {}

    template <class FloatC>
    __device__ void ReadXdlopsRegs(FloatC* const __restrict__) const
    {
    }
    template <class data_type_  = data_type,
              index_t MPerWave_ = GemmMPerWave,
              index_t NPerWave_ = GemmNPerWave>
    static constexpr auto GetXdlopsInfo();

    template <>
    static constexpr auto GetXdlopsInfo<float, 128, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 2, 1, c_vec32_4_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 64, 128>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 2, c_vec32_4_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 64, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, c_vec32_2_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 64, 32>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1, c_vec32_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 32, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1, c_vec32_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 64, 16>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 16, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 8, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1, c_vec4_2_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 4, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1, c_vec4_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 32, 32>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<float, 16, 16>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1, c_vec4_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1, c_vec32_4_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 64, 128>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2, c_vec32_4_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1, c_vec32_2_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1, c_vec32_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1, c_vec32_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1, c_vec4_2_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1, c_vec4_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1, c_vec4_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
    {
        return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
    {
        return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
    }

    template <>
    static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
    {
        return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
    }
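    // Illustration (not part of the commit): the specializations above form a
    // compile-time dispatch table from (data type, wave tile) to an MFMA plan.
    // A 128x64 fp32 wave tile, for example, becomes a 64x64 mfma_f32_32x32x1xf32
    // tile repeated twice along M, accumulating into c_vec32_4_t. Once the type
    // is complete, that choice is visible through the static members below:
    //
    //   using gemm_128x64 = XdlopsGemm_t<float, 128, 64, 1, 1>;
    //   static_assert(gemm_128x64::MPerXdlops == 64 && gemm_128x64::MRepeats == 2, "");
    //   static_assert(gemm_128x64::NPerXdlops == 64 && gemm_128x64::NRepeats == 1, "");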
    static constexpr index_t MRepeats   = GetXdlopsInfo().MRepeats;
    static constexpr index_t NRepeats   = GetXdlopsInfo().NRepeats;
    static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
    static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;

    static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
    static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();

    static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;

    struct OutputLayout
    {
        __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
        __device__ static constexpr index_t M0() { return mfma_type.group_size; }
        __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
        __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }

        __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }

        __device__ static constexpr index_t GetNumBlks()
        {
            return GetNumBlksPerXdlops() * MRepeats * NRepeats;
        }

        __device__ static constexpr auto CreateOutputVecZero()
        {
            return GetXdlopsInfo().OutputVecType.CreateVecZero();
        }
    };

    __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
};

} // namespace ck
#endif
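For orientation, here is a sketch (illustrative, not part of the commit; the function, buffer, and leading-dimension names are hypothetical) of how a caller can walk the output layout when spilling C from registers. GetNumBlks() enumerates the per-wave output blocks, GetBeginOfThreadBlk(i) gives a lane's first (row, col) in block i, and within a block a lane's GetBlkSize() registers advance through the M0/N1 grouping, the same stepping the emulation path above uses for aindex:

#include "xdlops_gemm.hpp" // the header introduced above

template <class CThreadBuf>
__device__ void store_c_sketch(const CThreadBuf& c_thread, float* p_c, ck::index_t ldc)
{
    using Gemm = ck::XdlopsGemm_t<float, 64, 64, 1, 1>; // illustrative 64x64 fp32 tile

    constexpr auto layout = Gemm{}.GetOutputLayout();

    for(ck::index_t i = 0; i < layout.GetNumBlks(); ++i)
    {
        const auto blk = Gemm::GetBeginOfThreadBlk(i); // this lane's first (row, col)

        for(ck::index_t j = 0; j < layout.GetBlkSize(); ++j)
        {
            // within a block a lane's registers advance group_size (= M0) rows at
            // a time, skipping the rows interleaved by the other N1 input blocks
            const ck::index_t row_in_blk =
                j % layout.M0() + (j / layout.M0()) * (layout.M0() * layout.N1());

            p_c[(blk.row + row_in_blk) * ldc + blk.col] =
                c_thread.n[i * layout.GetBlkSize() + j];
        }
    }
}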
composable_kernel/include/utility/amd_xdlops.hpp (new file, mode 100644)
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

#include "float_type.hpp"

namespace ck {

struct c_vec32_4_t
{
    union VecType
    {
        struct
        {
            float32_t x;
            float32_t y;
            float32_t z;
            float32_t w;
        } s;
        float n[128];
    };

    __host__ __device__ static VecType CreateVecZero()
    {
        VecType c;
        c.s.x = 0;
        c.s.y = 0;
        c.s.z = 0;
        c.s.w = 0;
        return c;
    }
};

struct c_vec32_2_t
{
    union VecType
    {
        struct
        {
            float32_t x;
            float32_t y;
        } s;
        float n[64];
    } l;

    __host__ __device__ static VecType CreateVecZero()
    {
        VecType c;
        c.s.x = 0;
        c.s.y = 0;
        return c;
    }
};

struct c_vec32_2_2_t
{
    union VecType
    {
        struct
        {
            c_vec32_2_t x;
            c_vec32_2_t y;
        } s;
        float n[128];
    };

    __host__ __device__ static VecType CreateVecZero()
    {
        VecType c;
        c.s.x.l.s.x = 0;
        c.s.x.l.s.y = 0;
        c.s.y.l.s.x = 0;
        c.s.y.l.s.y = 0;
        return c;
    }
};

struct c_vec32_1_t
{
    union VecType
    {
        struct
        {
            float32_t x;
        } s;
        float n[32];
    };

    __host__ __device__ static VecType CreateVecZero()
    {
        VecType c;
        c.s.x = 0;
        return c;
    }
};

struct c_vec16_1_t
{
    union VecType
    {
        struct
        {
            float16_t x;
        } s;
        float n[16];
    };

    __host__ __device__ static VecType CreateVecZero()
    {
        VecType c;
        c.s.x = 0;
        return c;
    }
};

struct c_vec4_2_t
{
    union VecType
    {
        struct
        {
            float4_t x;
            float4_t y;
        } s;
        float n[8];
    };

    __host__ __device__ static VecType CreateVecZero()
    {
        VecType c;
        c.s.x = 0;
        c.s.y = 0;
        return c;
    }
};

struct c_vec4_1_t
{
    union VecType
    {
        struct
        {
            float4_t x;
        } s;
        float n[4];
    };

    __host__ __device__ static VecType CreateVecZero()
    {
        VecType c;
        c.s.x = 0;
        return c;
    }
};
// A, B, C, cbsz, abid, blgp
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
    float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
    float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
    float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
    float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
    float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");

extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
    half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
    half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
    half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
    half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
    half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");

extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
    ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
    ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
    ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
    ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
    ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
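// Usage illustration (not part of the commit; assumes a gfx908 hipcc build with
// MAI instructions available). The three trailing ints are the broadcast
// modifiers cbsz/abid/blgp named in the comment above; 0, 0, 0 means no
// broadcast. Each of the 64 lanes of one wave feeds one A and one B element per
// issue and owns 16 accumulator VGPRs of the 32x32 fp32 result.
__global__ void mfma_32x32x2f32_smoke(const float* a, const float* b, float* c)
{
    float16_t acc;
    for(int i = 0; i < 16; ++i)
        acc[i] = 0.0f;

    acc = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(a[threadIdx.x], b[threadIdx.x], acc, 0, 0, 0);

    for(int i = 0; i < 16; ++i)
        c[threadIdx.x * 16 + i] = acc[i];
}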
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32;

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
{
    __device__ static c_vec32_2_t::VecType
    run(const float* reg_a, const float* reg_b, c_vec32_2_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
{
    __device__ static c_vec32_1_t::VecType
    run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
{
    __device__ static c_vec32_1_t::VecType
    run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        return reg_c;
    }
};

__device__ c_vec16_1_t::VecType
intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
    return reg_c;
}

__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
    return reg_c;
}

template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x1f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c);

template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<16, 64>(
    const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
    return reg_c;
}

template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<64, 16>(
    const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
    return reg_c;
}

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;

template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
{
    __device__ static c_vec4_1_t::VecType
    run(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
        return reg_c;
    }
};

template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
{
    __device__ static c_vec4_2_t::VecType
    run(const float* reg_a, const float* reg_b, c_vec4_2_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
        return reg_c;
    }
};

template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16;

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<128, 64, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 128, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 64, AStride, BStride>
{
    __device__ static c_vec32_2_t::VecType
    run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_2_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 32, AStride, BStride>
{
    __device__ static c_vec32_1_t::VecType
    run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<32, 64, AStride, BStride>
{
    __device__ static c_vec32_1_t::VecType
    run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        return reg_c;
    }
};

__device__ c_vec16_1_t::VecType
intrin_mfma_f32_32x32x8f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
    return reg_c;
}

__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x16f16(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
    return reg_c;
}

template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x4f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c);

template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<16, 64>(
    const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
    return reg_c;
}

template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<64, 16>(
    const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
    return reg_c;
}

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;

template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
{
    __device__ static c_vec4_1_t::VecType
    run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
        return reg_c;
    }
};

template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
{
    __device__ static c_vec4_2_t::VecType
    run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_2_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
        return reg_c;
    }
};

template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16;

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
{
    __device__ static c_vec32_4_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
    {
        reg_c.s.x =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        reg_c.s.z =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
        reg_c.s.w =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
{
    __device__ static c_vec32_2_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
    {
        reg_c.s.x =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        reg_c.s.y =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
{
    __device__ static c_vec32_1_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
    {
        reg_c.s.x =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
        return reg_c;
    }
};

template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
{
    __device__ static c_vec32_1_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
    {
        reg_c.s.x =
            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
        return reg_c;
    }
};

__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
                                                            const ushort2_t* reg_b,
                                                            c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
    return reg_c;
}

__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
                                                           const ushort2_t* reg_b,
                                                           c_vec4_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
    return reg_c;
}

template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
                                                            const ushort2_t* reg_b,
                                                            c_vec16_1_t::VecType reg_c);

template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(
    const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
    return reg_c;
}

template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(
    const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec16_1_t::VecType reg_c)
{
    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
    return reg_c;
}

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x2bf16;

template <>
struct intrin_mfma_f32_4x4x2bf16<4, 64>
{
    __device__ static c_vec4_1_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
        return reg_c;
    }
};

template <>
struct intrin_mfma_f32_4x4x2bf16<8, 64>
{
    __device__ static c_vec4_2_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
    {
        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
        return reg_c;
    }
};

} // namespace ck
#endif
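Taken together with xdlops_gemm.hpp, the call chain is: XdlopsGemm_t::Run picks an mfma_info trait, whose run() casts the packed operands and forwards to one of the typed wrappers above, which expands into one llvm.amdgcn.mfma.* issue per output sub-block. A one-step usage sketch (illustrative; the function name is hypothetical):

__device__ void one_xdlops_step(const float* pa, const float* pb,
                                ck::c_vec32_2_t::VecType& acc)
{
    // 64x64 fp32 tile per wave: the <64, 64> specialization issues two
    // mfma_f32_32x32x1f32 calls, selecting the output block via abid.
    // AStride/BStride (here 1, 1) are unused by this specialization.
    acc = ck::intrin_mfma_f32_32x32x1f32<64, 64, 1, 1>::run(pa, pb, acc);
}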