Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5dbaf3c2
Commit
5dbaf3c2
authored
Aug 18, 2021
by
Jing Zhang
Browse files
refactor xdlops, hide c desc
parent
370c9245
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
155 additions
and
171 deletions
+155
-171
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+8
-8
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+147
-163
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
5dbaf3c2
...
@@ -118,13 +118,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -118,13 +118,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2ThreadDescriptor
()
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2ThreadDescriptor
()
{
{
///\to-do: hide xdl clayout into xdlops-gemm
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
CXdlopsLayout
=
xdlops_gemm
.
GetCXdlopsLayout
();
constexpr
auto
M0
=
Number
<
CXdlopsLayout
.
M1
()
>
{};
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M2
=
Number
<
CXdlopsLayout
.
M0
()
>
{};
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
I
1
,
M2
,
I1
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
M
1
,
M2
,
N
));
}
}
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2BlockDescriptor
()
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2BlockDescriptor
()
...
@@ -195,7 +196,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -195,7 +196,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type
<
FloatAB
,
K1
>
b_thread_vec
;
vector_type
<
FloatAB
,
K1
>
b_thread_vec
;
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
/
xdlops_gemm
.
KPerThread
>
{}([
&
](
auto
k0
)
{
// read A
// read A
a_thread_copy_
.
Run
(
a_k0_m0_m1_m2_k1_block_desc
,
a_thread_copy_
.
Run
(
a_k0_m0_m1_m2_k1_block_desc
,
make_tuple
(
k0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
k0
,
I0
,
I0
,
I0
,
I0
),
...
@@ -212,8 +213,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -212,8 +213,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
b_thread_buf
);
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
KPerThread
>::
type
;
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
mfma_type
.
k_per_blk
>::
type
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
5dbaf3c2
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
namespace
ck
{
namespace
ck
{
enum
struct
m
fma
_i
nstr
enum
struct
M
fma
I
nstr
{
{
mfma_f32_32x32x1xf32
=
0
,
mfma_f32_32x32x1xf32
=
0
,
mfma_f32_16x16x1xf32
,
mfma_f32_16x16x1xf32
,
...
@@ -26,11 +26,11 @@ enum struct mfma_instr
...
@@ -26,11 +26,11 @@ enum struct mfma_instr
mfma_f32_16x16x8bf16
,
// k reduction
mfma_f32_16x16x8bf16
,
// k reduction
};
};
template
<
m
fma
_i
nstr
instr
>
template
<
M
fma
I
nstr
instr
>
struct
mfma_
info
;
struct
mfma_
type
;
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_32x32x1xf32
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_32x32x1xf32
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
...
@@ -57,7 +57,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
...
@@ -57,7 +57,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
};
};
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_32x32x2xf32
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_32x32x2xf32
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
...
@@ -84,7 +84,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
...
@@ -84,7 +84,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
};
};
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_16x16x4xf32
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_16x16x4xf32
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
...
@@ -111,7 +111,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
...
@@ -111,7 +111,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
};
};
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_16x16x1xf32
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_16x16x1xf32
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
...
@@ -139,7 +139,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
...
@@ -139,7 +139,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
// treat 4x4x1 as a single-blk 4x64 mfma
// treat 4x4x1 as a single-blk 4x64 mfma
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_4x4x1xf32
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_4x4x1xf32
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
...
@@ -166,7 +166,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
...
@@ -166,7 +166,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
};
};
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_32x32x4f16
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_32x32x4f16
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
...
@@ -193,7 +193,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
...
@@ -193,7 +193,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
};
};
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_32x32x8f16
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_32x32x8f16
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
...
@@ -220,7 +220,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
...
@@ -220,7 +220,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
};
};
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_16x16x16f16
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_16x16x16f16
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
...
@@ -247,7 +247,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
...
@@ -247,7 +247,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
};
};
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_16x16x4f16
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_16x16x4f16
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
...
@@ -274,7 +274,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
...
@@ -274,7 +274,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
};
};
template
<
>
template
<
>
struct
mfma_
info
<
m
fma
_i
nstr
::
mfma_f32_4x4x4f16
>
struct
mfma_
type
<
M
fma
I
nstr
::
mfma_f32_4x4x4f16
>
{
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
...
@@ -302,7 +302,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
...
@@ -302,7 +302,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
#if 0
#if 0
template <>
template <>
struct mfma_
info<m
fma
_i
nstr::mfma_f32_32x32x2bf16>
struct mfma_
type<M
fma
I
nstr::mfma_f32_32x32x2bf16>
{
{
static constexpr index_t group_size = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_groups_per_blk = 4;
...
@@ -334,7 +334,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
...
@@ -334,7 +334,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
};
};
template <>
template <>
struct mfma_
info<m
fma
_i
nstr::mfma_f32_32x32x4bf16>
struct mfma_
type<M
fma
I
nstr::mfma_f32_32x32x4bf16>
{
{
static constexpr index_t group_size = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_groups_per_blk = 4;
...
@@ -365,7 +365,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
...
@@ -365,7 +365,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
};
};
template <>
template <>
struct mfma_
info<m
fma
_i
nstr::mfma_f32_16x16x8bf16>
struct mfma_
type<M
fma
I
nstr::mfma_f32_16x16x8bf16>
{
{
static constexpr index_t group_size = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_groups_per_blk = 1;
...
@@ -396,7 +396,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
...
@@ -396,7 +396,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
};
};
template <>
template <>
struct mfma_
info<m
fma
_i
nstr::mfma_f32_16x16x2bf16>
struct mfma_
type<M
fma
I
nstr::mfma_f32_16x16x2bf16>
{
{
static constexpr index_t group_size = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_groups_per_blk = 1;
...
@@ -427,7 +427,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
...
@@ -427,7 +427,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
};
};
template <>
template <>
struct mfma_
info<m
fma
_i
nstr::mfma_f32_4x4x2bf16>
struct mfma_
type<M
fma
I
nstr::mfma_f32_4x4x2bf16>
{
{
static constexpr index_t group_size = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_groups_per_blk = 1;
...
@@ -458,229 +458,229 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
...
@@ -458,229 +458,229 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
};
};
#endif
#endif
template
<
mfma_instr
instr
,
index_t
MPerXdlops
_
,
index_t
NPerXdlops
_
>
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
struct
xdlops_info
struct
MfmaSelector
{
{
static
constexpr
auto
mfma_type
=
mfma_info
<
instr
>
{};
template
<
typename
base_type_
,
index_t
MPerXdlops_
,
index_t
NPerXdlops_
>
static
constexpr
auto
GetMfma
();
static
constexpr
index_t
MPerXdlops
=
MPerXdlops_
;
static
constexpr
index_t
NPerXdlops
=
NPerXdlops_
;
static
constexpr
bool
IsABroadcast
()
{
static_assert
(
NPerXdlops
>=
MPerXdlops
,
"only support ABroadcast"
);
return
true
;
}
static
constexpr
index_t
GetKPerXdlops
()
{
return
mfma_type
.
is_k_reduction
?
mfma_type
.
num_input_blks
:
1
;
}
static
constexpr
index_t
GetNumCRegs
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
};
template
<
class
base_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPack
>
struct
XdlopsGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
template
<
class
base_type_
=
base_type
,
index_t
MPerWave_
=
MPerWave
,
index_t
NPerWave_
=
NPerWave
>
static
constexpr
auto
GetXdlopsInfo
();
template
<
>
template
<
>
static
constexpr
auto
Get
XdlopsInfo
<
float
,
64
,
64
>
()
static
constexpr
auto
Get
Mfma
<
float
,
64
,
64
>
()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_32x32x1xf32
,
64
,
64
>
{}
;
return
M
fma
I
nstr
::
mfma_f32_32x32x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
Get
XdlopsInfo
<
float
,
32
,
64
>
()
static
constexpr
auto
Get
Mfma
<
float
,
32
,
64
>
()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_32x32x1xf32
,
32
,
64
>
{}
;
return
M
fma
I
nstr
::
mfma_f32_32x32x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
Get
XdlopsInfo
<
float
,
16
,
64
>
()
static
constexpr
auto
Get
Mfma
<
float
,
16
,
64
>
()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_16x16x1xf32
,
16
,
64
>
{}
;
return
M
fma
I
nstr
::
mfma_f32_16x16x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
Get
XdlopsInfo
<
float
,
8
,
64
>
()
static
constexpr
auto
Get
Mfma
<
float
,
8
,
64
>
()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_4x4x1xf32
,
8
,
64
>
{}
;
return
M
fma
I
nstr
::
mfma_f32_4x4x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
Get
XdlopsInfo
<
float
,
4
,
64
>
()
static
constexpr
auto
Get
Mfma
<
float
,
4
,
64
>
()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_4x4x1xf32
,
4
,
64
>
{}
;
return
M
fma
I
nstr
::
mfma_f32_4x4x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
Get
XdlopsInfo
<
float
,
32
,
32
>
()
static
constexpr
auto
Get
Mfma
<
float
,
32
,
32
>
()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_32x32x2xf32
,
32
,
32
>
{}
;
return
M
fma
I
nstr
::
mfma_f32_32x32x2xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
Get
XdlopsInfo
<
float
,
16
,
16
>
()
static
constexpr
auto
Get
Mfma
<
float
,
16
,
16
>
()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_16x16x4xf32
,
16
,
16
>
{}
;
return
M
fma
I
nstr
::
mfma_f32_16x16x4xf32
;
}
}
#if 0
template <>
template <>
static
constexpr
auto
Get
XdlopsInfo
<
half_t
,
64
,
64
>
()
static constexpr auto Get
Mfma
<half_t, 64, 64>()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_32x32x4f16
,
64
,
64
>
{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x4f16, 64, 64>{};
}
}
template <>
template <>
static
constexpr
auto
Get
XdlopsInfo
<
half_t
,
32
,
64
>
()
static constexpr auto Get
Mfma
<half_t, 32, 64>()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_32x32x4f16
,
32
,
64
>
{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x4f16, 32, 64>{};
}
}
template <>
template <>
static
constexpr
auto
Get
XdlopsInfo
<
half_t
,
32
,
32
>
()
static constexpr auto Get
Mfma
<half_t, 32, 32>()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_32x32x8f16
,
32
,
32
>
{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x8f16, 32, 32>{};
}
}
template <>
template <>
static
constexpr
auto
Get
XdlopsInfo
<
half_t
,
16
,
16
>
()
static constexpr auto Get
Mfma
<half_t, 16, 16>()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_16x16x16f16
,
16
,
16
>
{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_16x16x16f16, 16, 16>{};
}
}
template <>
template <>
static
constexpr
auto
Get
XdlopsInfo
<
half_t
,
16
,
64
>
()
static constexpr auto Get
Mfma
<half_t, 16, 64>()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_16x16x4f16
,
16
,
64
>
{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_16x16x4f16, 16, 64>{};
}
}
template <>
template <>
static
constexpr
auto
Get
XdlopsInfo
<
half_t
,
8
,
64
>
()
static constexpr auto Get
Mfma
<half_t, 8, 64>()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_4x4x4f16
,
8
,
64
>
{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_4x4x4f16, 8, 64>{};
}
}
template <>
template <>
static
constexpr
auto
Get
XdlopsInfo
<
half_t
,
4
,
64
>
()
static constexpr auto Get
Mfma
<half_t, 4, 64>()
{
{
return
xdlops_info
<
m
fma
_i
nstr
::
mfma_f32_4x4x4f16
,
4
,
64
>
{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_4x4x4f16, 4, 64>{};
}
}
#if 0
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 128, 64>()
static constexpr auto Get
Mfma
<ushort, 128, 64>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 64, 128>()
static constexpr auto Get
Mfma
<ushort, 64, 128>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 64, 64>()
static constexpr auto Get
Mfma
<ushort, 64, 64>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 64, 32>()
static constexpr auto Get
Mfma
<ushort, 64, 32>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 32, 64>()
static constexpr auto Get
Mfma
<ushort, 32, 64>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 64, 16>()
static constexpr auto Get
Mfma
<ushort, 64, 16>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 16, 64>()
static constexpr auto Get
Mfma
<ushort, 16, 64>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 8, 64>()
static constexpr auto Get
Mfma
<ushort, 8, 64>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 4, 64>()
static constexpr auto Get
Mfma
<ushort, 4, 64>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 32, 32>()
static constexpr auto Get
Mfma
<ushort, 32, 32>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
}
}
template <>
template <>
static constexpr auto Get
XdlopsInfo
<ushort, 16, 16>()
static constexpr auto Get
Mfma
<ushort, 16, 16>()
{
{
return xdlops_info<
m
fma
_i
nstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
return xdlops_info<
M
fma
I
nstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
}
}
#endif
#endif
using
CIndex
=
MultiIndex
<
2
>
;
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
__device__
static
constexpr
index_t
GetNumBlks
()
{
return
mfma_type
.
num_output_blks
;
}
__device__
static
constexpr
index_t
GetNumXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
(
mfma_type
.
m_per_blk
*
mfma_type
.
n_per_blk
*
mfma_type
.
num_output_blks
);
}
__host__
__device__
static
void
mfma_
info_
check
()
__host__
__device__
static
constexpr
void
mfma_check
()
{
{
static_assert
(
mfma_type
.
group_size
*
mfma_type
.
num_groups_per_blk
==
mfma_type
.
num_regs_per_blk
,
static_assert
(
selected_mfma
.
group_size
*
selected_mfma
.
num_groups_per_blk
==
selected_mfma
.
num_regs_per_blk
,
"wrong! num_regs_per_blk"
);
"wrong! num_regs_per_blk"
);
static_assert
(
mfma_type
.
num_threads_per_blk
==
mfma_type
.
n_per_blk
,
static_assert
(
selected_mfma
.
num_threads_per_blk
==
selected_mfma
.
n_per_blk
,
"n_per_blk != num_threads_per_blk"
);
"n_per_blk != num_threads_per_blk"
);
static_assert
(
mfma_type
.
num_regs_per_blk
*
mfma_type
.
num_input_blks
==
mfma_type
.
m_per_blk
,
static_assert
(
selected_mfma
.
num_regs_per_blk
*
selected_mfma
.
num_input_blks
==
selected_mfma
.
m_per_blk
,
"m_per_blk != num_input_blks * num_regs_per_blk"
);
"m_per_blk != num_input_blks * num_regs_per_blk"
);
static_assert
(
mfma_type
.
num_output_blks
==
mfma_type
.
num_input_blks
||
static_assert
(
selected_mfma
.
num_output_blks
==
selected_mfma
.
num_input_blks
||
mfma_type
.
num_output_blks
==
1
,
selected_mfma
.
num_output_blks
==
1
,
"incorrect num_output_blks"
);
"incorrect num_output_blks"
);
static_assert
(
mfma_type
.
num_regs_per_blk
*
mfma_type
.
wave_size
==
static_assert
(
selected_mfma
.
num_regs_per_blk
*
selected_mfma
.
wave_size
==
mfma_type
.
m_per_blk
*
mfma_type
.
n_per_blk
,
selected_mfma
.
m_per_blk
*
selected_mfma
.
n_per_blk
,
"num_regs_per_blk incorrect"
);
"num_regs_per_blk incorrect"
);
static_assert
(
mfma_type
.
is_k_reduction
||
static_assert
(
selected_mfma
.
is_k_reduction
||
(
mfma_type
.
num_input_blks
==
mfma_type
.
num_output_blks
),
(
selected_mfma
.
num_input_blks
==
selected_mfma
.
num_output_blks
),
"is_k_reduction wrong!"
);
"is_k_reduction wrong!"
);
}
}
__host__
__device__
constexpr
MfmaSelector
()
{
mfma_check
();
}
static
constexpr
bool
IsABroadcast
()
{
static_assert
(
NPerXdlops
>=
MPerXdlops
,
"only support ABroadcast"
);
return
true
;
}
static
constexpr
index_t
GetKPerXdlops
()
{
return
(
selected_mfma
.
is_k_reduction
?
selected_mfma
.
num_input_blks
:
1
)
*
selected_mfma
.
k_per_blk
;
}
static
constexpr
index_t
GetKPerThread
()
{
return
selected_mfma
.
k_per_blk
;
}
};
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
KPack
>
struct
XdlopsGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
CIndex
=
MultiIndex
<
2
>
;
__device__
static
constexpr
index_t
GetNumBlks
()
{
return
mfma_instr
.
num_output_blks
;
}
__device__
static
constexpr
index_t
GetNumXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
(
mfma_instr
.
m_per_blk
*
mfma_instr
.
n_per_blk
*
mfma_instr
.
num_output_blks
);
}
__host__
__device__
constexpr
XdlopsGemm
()
__host__
__device__
constexpr
XdlopsGemm
()
{
{
static_assert
(
NPerXdlops
==
4
||
NPerXdlops
==
8
||
NPerXdlops
==
16
||
NPerXdlops
==
32
||
static_assert
(
NPerXdlops
==
4
||
NPerXdlops
==
8
||
NPerXdlops
==
16
||
NPerXdlops
==
32
||
...
@@ -690,6 +690,8 @@ struct XdlopsGemm
...
@@ -690,6 +690,8 @@ struct XdlopsGemm
static_assert
(
MPerXdlops
==
4
||
MPerXdlops
==
8
||
MPerXdlops
==
16
||
MPerXdlops
==
32
||
static_assert
(
MPerXdlops
==
4
||
MPerXdlops
==
8
||
MPerXdlops
==
16
||
MPerXdlops
==
32
||
MPerXdlops
==
64
,
MPerXdlops
==
64
,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
static_assert
(
KPack
%
mfma_instr
.
k_per_blk
==
0
,
"KPack cannot be divided by k_per_blk"
);
}
}
template
<
typename
CM0N0M1N1M2N2Desc
>
template
<
typename
CM0N0M1N1M2N2Desc
>
...
@@ -707,10 +709,10 @@ struct XdlopsGemm
...
@@ -707,10 +709,10 @@ struct XdlopsGemm
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
N1
),
make_pass_through_transform
(
N1
),
make_unmerge_transform
(
make_tuple
(
mfma_
type
.
num_groups_per_blk
,
make_unmerge_transform
(
make_tuple
(
mfma_
instr
.
num_groups_per_blk
,
mfma_
type
.
num_input_blks
,
mfma_
instr
.
num_input_blks
,
mfma_
type
.
group_size
)),
mfma_
instr
.
group_size
)),
make_pass_through_transform
(
mfma_
type
.
num_threads_per_blk
)),
make_pass_through_transform
(
mfma_
instr
.
num_threads_per_blk
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -727,7 +729,7 @@ struct XdlopsGemm
...
@@ -727,7 +729,7 @@ struct XdlopsGemm
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
{
{
return
MPerXdlops
*
NPerXdlops
/
mfma_
type
.
wave_size
;
return
MPerXdlops
*
NPerXdlops
/
mfma_
instr
.
wave_size
;
}
}
template
<
index_t
c_offset
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
index_t
c_offset
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
...
@@ -737,22 +739,20 @@ struct XdlopsGemm
...
@@ -737,22 +739,20 @@ struct XdlopsGemm
is_same
<
base_type
,
ushort
>::
value
,
is_same
<
base_type
,
ushort
>::
value
,
"base base_type must be float, half, ushort!"
);
"base base_type must be float, half, ushort!"
);
static_assert
(
KPack
%
mfma_type
.
k_per_blk
==
0
,
"KPack cannot be divided by k_per_blk"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
,
c_offset
>(
static_for
<
0
,
KPack
/
mfma_type
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
,
c_offset
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
});
});
}
}
__device__
static
auto
GetLaneId
()
{
return
get_thread_local_1d_id
()
%
mfma_
type
.
wave_size
;
}
__device__
static
auto
GetLaneId
()
{
return
get_thread_local_1d_id
()
%
mfma_
instr
.
wave_size
;
}
__device__
static
auto
GetBlkIdx
()
__device__
static
auto
GetBlkIdx
()
{
{
const
auto
laneId
=
GetLaneId
();
const
auto
laneId
=
GetLaneId
();
const
auto
threadidx_to_blk_idx_adaptor
=
make_single_stage_tensor_adaptor
(
const
auto
threadidx_to_blk_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
make_merge_transform
(
make_tuple
(
1
,
mfma_
type
.
num_input_blks
,
mfma_
type
.
num_threads_per_blk
))),
make_tuple
(
1
,
mfma_
instr
.
num_input_blks
,
mfma_
instr
.
num_threads_per_blk
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
0
>
{}));
...
@@ -773,7 +773,7 @@ struct XdlopsGemm
...
@@ -773,7 +773,7 @@ struct XdlopsGemm
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
mfma_
type
.
is_k_reduction
)
if
constexpr
(
mfma_
instr
.
is_k_reduction
)
{
{
return
make_tuple
(
blk_id
,
blk_td
);
return
make_tuple
(
blk_id
,
blk_td
);
}
}
...
@@ -791,7 +791,7 @@ struct XdlopsGemm
...
@@ -791,7 +791,7 @@ struct XdlopsGemm
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
mfma_
type
.
is_k_reduction
)
if
constexpr
(
mfma_
instr
.
is_k_reduction
)
{
{
return
make_tuple
(
blk_id
,
blk_td
);
return
make_tuple
(
blk_id
,
blk_td
);
}
}
...
@@ -803,45 +803,29 @@ struct XdlopsGemm
...
@@ -803,45 +803,29 @@ struct XdlopsGemm
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
xdlops_i
,
index_t
blk_i
)
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
xdlops_i
,
index_t
blk_i
)
{
{
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
const
auto
blk_td
=
blk_idx
[
I1
];
index_t
n_offset
=
blk_i
*
mfma_
type
.
n_per_blk
+
blk_td
;
index_t
n_offset
=
blk_i
*
mfma_
instr
.
n_per_blk
+
blk_td
;
index_t
m_offset
=
xdlops_i
*
mfma_
type
.
m_per_blk
+
blk_id
*
mfma_
type
.
group_size
;
index_t
m_offset
=
xdlops_i
*
mfma_
instr
.
m_per_blk
+
blk_id
*
mfma_
instr
.
group_size
;
return
CIndex
{
m_offset
,
n_offset
};
return
CIndex
{
m_offset
,
n_offset
};
}
}
static
constexpr
index_t
MPerXdlops
=
GetXdlopsInfo
().
MPerXdlops
;
static
constexpr
auto
mfma
=
MfmaSelector
<
base_type
,
MPerXdlops
,
NPerXdlops
>
{};
static
constexpr
index_t
NPerXdlops
=
GetXdlopsInfo
().
NPerXdlops
;
static
constexpr
index_t
KPerXdlops
=
GetXdlopsInfo
().
GetKPerXdlops
();
static
constexpr
bool
IsABroadca
st
=
GetXdlopsInfo
().
IsABroadcast
()
;
static
constexpr
auto
mfma_in
st
r
=
mfma
.
selected_mfma
;
static
constexpr
auto
mfma_type
=
GetXdlopsInfo
().
mfma_type
;
static
constexpr
auto
KPerXdlops
=
mfma
.
GetKPerXdlops
();
static
constexpr
auto
KPerThread
=
mfma
.
GetKPerThread
();
struct
CLayout
__host__
__device__
static
constexpr
auto
GetCM0M1M2NThreadBlkLengths
()
{
{
__host__
__device__
static
constexpr
index_t
M1
()
{
return
mfma_type
.
num_groups_per_blk
;
}
return
make_tuple
(
__host__
__device__
static
constexpr
index_t
M0
()
{
return
mfma_type
.
group_size
;
}
Number
<
mfma_instr
.
num_groups_per_blk
>
{},
I1
,
Number
<
mfma_instr
.
group_size
>
{},
I1
);
__host__
__device__
static
constexpr
index_t
N1
()
{
return
mfma_type
.
num_input_blks
;
}
}
__host__
__device__
static
constexpr
index_t
N0
()
{
return
mfma_type
.
num_threads_per_blk
;
}
__device__
static
constexpr
index_t
GetBlkSize
()
{
return
mfma_type
.
num_regs_per_blk
;
}
__device__
static
constexpr
index_t
GetNumBlks
()
{
return
mfma_type
.
num_output_blks
;
}
__device__
static
constexpr
index_t
GetNumXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
(
mfma_type
.
m_per_blk
*
mfma_type
.
n_per_blk
*
mfma_type
.
num_output_blks
);
}
};
__host__
__device__
static
constexpr
auto
GetCXdlopsLayout
()
{
return
CLayout
{};
}
};
};
}
// namespace ck
}
// namespace ck
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment