Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3973caa4
Commit
3973caa4
authored
Feb 07, 2023
by
illsilin
Browse files
switch between intrinsic mfma routines on mi100/200 and mi300
parent
dc58fa9a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
0 deletions
+76
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+55
-0
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+21
-0
No files found.
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
3973caa4
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
namespace
ck
{
namespace
ck
{
#if (defined(__gfx908__) || defined(__gfx90a__))
enum
struct
MfmaInstr
enum
struct
MfmaInstr
{
{
mfma_f32_32x32x1xf32
=
0
,
mfma_f32_32x32x1xf32
=
0
,
...
@@ -29,6 +30,28 @@ enum struct MfmaInstr
...
@@ -29,6 +30,28 @@ enum struct MfmaInstr
mfma_i32_16x16x16i8
,
mfma_i32_16x16x16i8
,
mfma_f64_16x16x4f64
mfma_f64_16x16x4f64
};
};
#elif (defined(__gfx940__))
// Identifiers for the MFMA instruction variants selectable in this build.
// This definition sits in the `#elif (defined(__gfx940__))` branch, so it is the
// one the compiler sees only when __gfx940__ is defined. It lists
// mfma_i32_32x32x16i8 as its 32x32 int8 entry; the gfx908/gfx90a branch of this
// file specializes mfma_type for mfma_i32_32x32x8i8 instead.
// Enumerators are numbered consecutively from 0 in declaration order.
enum struct MfmaInstr
{
    mfma_f32_32x32x1xf32 = 0,
    mfma_f32_16x16x1xf32,
    mfma_f32_4x4x1xf32,
    mfma_f32_32x32x2xf32,
    mfma_f32_16x16x4xf32,
    mfma_f32_32x32x4f16,
    mfma_f32_16x16x4f16,
    mfma_f32_4x4x4f16,
    mfma_f32_32x32x8f16,
    mfma_f32_16x16x16f16,
    mfma_f32_32x32x8bf16_1k,
    mfma_f32_16x16x16bf16_1k,
    mfma_f32_32x32x4bf16,
    mfma_f32_16x16x8bf16,
    mfma_i32_32x32x16i8,
    mfma_i32_16x16x16i8,
    mfma_f64_16x16x4f64
};
#endif
template
<
MfmaInstr
instr
>
template
<
MfmaInstr
instr
>
struct
mfma_type
;
struct
mfma_type
;
...
@@ -342,6 +365,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
...
@@ -342,6 +365,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
}
}
};
};
#if (defined(__gfx908__) || defined(__gfx90a__))
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_i32_32x32x8i8
>
struct
mfma_type
<
MfmaInstr
::
mfma_i32_32x32x8i8
>
{
{
...
@@ -363,6 +387,29 @@ struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
...
@@ -363,6 +387,29 @@ struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
intrin_mfma_i32_32x32x8i8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
intrin_mfma_i32_32x32x8i8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
#elif (defined(__gfx940__))
// Compile-time description of MfmaInstr::mfma_i32_32x32x16i8 plus a thin
// dispatcher to the matching intrinsic wrapper.
//
// NOTE(review): the instruction name suggests K = 16. If K is meant to equal
// k_per_blk * num_input_blks, then k_per_blk would be expected to be 8 rather
// than 4 -- TODO confirm against the ISA and the operand width used by
// intrin_mfma_i32_32x32x16i8::Run before changing either side.
template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
{
    static constexpr index_t group_size          = 4;
    static constexpr index_t num_groups_per_blk  = 4;
    static constexpr index_t num_regs_per_blk    = 16;
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size           = 64;
    static constexpr index_t num_input_blks      = 2;
    static constexpr index_t num_output_blks     = 1;
    static constexpr index_t m_per_blk           = 32;
    static constexpr index_t n_per_blk           = 32;
    static constexpr index_t k_per_blk           = 4;
    static constexpr bool is_k_reduction         = true;

    // Forwards the operands `a` and `b` and the accumulator `reg_c` unchanged to
    // intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run.
    template <index_t MPerXdlops,
              index_t NPerXdlops,
              typename FloatA,
              typename FloatB,
              typename FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        using Intrinsic = intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>;
        Intrinsic::Run(a, b, reg_c);
    }
};
#endif
template
<
>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_i32_16x16x16i8
>
struct
mfma_type
<
MfmaInstr
::
mfma_i32_16x16x16i8
>
...
@@ -524,11 +571,19 @@ struct MfmaSelector
...
@@ -524,11 +571,19 @@ struct MfmaSelector
#endif
#endif
}
}
#if (defined(__gfx908__) || defined(__gfx90a__))
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
{
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
}
}
#elif (defined(__gfx940__))
// int8_t, 32x32 selection for the __gfx940__ branch: returns the
// mfma_i32_32x32x16i8 instruction id (the gfx908/gfx90a branch returns
// mfma_i32_32x32x8i8 for the same template arguments).
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
{
    return MfmaInstr::mfma_i32_32x32x16i8;
}
#endif
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
...
...
include/ck/utility/amd_xdlops.hpp
View file @
3973caa4
...
@@ -259,6 +259,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
...
@@ -259,6 +259,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
}
}
};
};
#if (defined(__gfx908__) || defined(__gfx90a__))
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_i32_32x32x8i8
;
struct
intrin_mfma_i32_32x32x8i8
;
...
@@ -277,6 +278,26 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
...
@@ -277,6 +278,26 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
0
);
0
);
}
}
};
};
#elif (defined(__gfx940__))
// Wrapper around __builtin_amdgcn_mfma_i32_32x32x16_i8, compiled only in the
// __gfx940__ branch. Only the <32, 32> specialization is defined here.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;

template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
    // Issues one MFMA builtin call: element 0 of reg_c, viewed as int32x16_t,
    // is read as the accumulator input and overwritten with the result.
    // The three trailing builtin arguments are passed as literal 0.
    //
    // NOTE(review): both operands are bit-cast from int8x4_t to a single
    // int32_t. Confirm against the builtin's declared operand type; if it takes
    // 64-bit packed operands, these 32-bit values would be implicitly widened
    // and half of each operand would carry no data -- TODO verify.
    template <typename FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        const auto packed_a = bit_cast<int32_t>(reg_a);
        const auto packed_b = bit_cast<int32_t>(reg_b);

        reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x16_i8(
            packed_a, packed_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
    }
};
#endif
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_i32_16x16x16i8
;
struct
intrin_mfma_i32_16x16x16i8
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment