gaoqiong / composable_kernel · Commits

Commit 4ea89209, authored Jun 01, 2021 by Jing Zhang
Parent: 822856e1

add 32x32x8fp16
Showing 3 changed files with 35 additions and 25 deletions (+35 -25):

  composable_kernel/include/tensor_operation/xdlops_gemm.hpp  (+12 -15)
  composable_kernel/include/utility/amd_xdlops.hpp  (+18 -5)
  driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  (+5 -5)
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
@@ -227,17 +227,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
     template <index_t MPerXdlops,
               index_t NPerXdlops,
               index_t AStride,
               index_t BStride,
               index_t COffset,
               class FloatA,
               class FloatB,
               class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = reinterpret_cast<const half4_t*>(a);
-        const auto p_b = reinterpret_cast<const half4_t*>(b);
-
-        return intrin_mfma_f32_32x32x8f16(p_a, p_b, reg_c);
+        intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
     }
 };
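Note on the change above: mfma_info::run no longer takes raw pointers, reinterpret_casts them to half4_t and calls a free intrinsic wrapper; it now forwards vector-typed references to a struct template selected by the compile-time tile parameters MPerXdlops/NPerXdlops/COffset. The following plain C++ sketch illustrates that dispatch pattern only; the names tile_mfma, half4, vec16 and run_tile are invented for the example, and the arithmetic stands in for the real MFMA intrinsic.

#include <array>
#include <cstdio>

using half4 = std::array<float, 4>;  // stand-in for half4_t
using vec16 = std::array<float, 16>; // stand-in for the thread's C accumulator

template <int MPerXdlops, int NPerXdlops, int COffset>
struct tile_mfma; // primary template left undefined: only supported tiles compile

template <int COffset>
struct tile_mfma<32, 32, COffset>
{
    template <class FloatC>
    static void Run(const half4& a, const half4& b, FloatC& c)
    {
        // the real specialization calls llvm_intrin_amdgcn_mfma_f32_32x32x8f16 here
        for(int i = 0; i < 4; ++i)
            c[COffset] += a[i] * b[i];
    }
};

// generic caller, shaped like mfma_info::run after the change
template <int M, int N, int COffset, class FloatA, class FloatB, class FloatC>
void run_tile(const FloatA& a, const FloatB& b, FloatC& c)
{
    tile_mfma<M, N, COffset>::Run(a, b, c);
}

int main()
{
    half4 a{1, 2, 3, 4};
    half4 b{5, 6, 7, 8};
    vec16 c{};
    run_tile<32, 32, 0>(a, b, c); // resolved to the 32x32 specialization at compile time
    std::printf("c[0] = %f\n", c[0]);
    return 0;
}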
@@ -589,19 +585,18 @@ struct XdlopsGemm
         return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
     }

-#if 0
     template <>
-    static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
+    static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
     {
-        return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1, c_vec32_1_t>{};
+        return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{};
     }

     template <>
-    static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
+    static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
     {
-        return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1, c_vec32_1_t>{};
+        return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
     }

 #if 0
     template <>
     static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
     {
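For context, GetXdlopsInfo is the compile-time map from (data type, MPerXdlops, NPerXdlops) to an xdlops_info tag naming the MFMA instruction to issue; the hunk above makes the 32x32 fp16 tile select mfma_f32_32x32x8f16. Below is a rough stand-alone illustration of such a map, written with if constexpr rather than the member-template specializations the file uses, and with invented names (instr, tile_info, get_tile_info).

#include <cstdio>

enum class instr { f32_32x32x4f16, f32_32x32x8f16 };

template <instr I, int M, int N>
struct tile_info
{
    static constexpr instr which = I;
    static constexpr int m = M;
    static constexpr int n = N;
};

template <int MPerXdlops, int NPerXdlops>
constexpr auto get_tile_info()
{
    if constexpr(MPerXdlops == 32 && NPerXdlops == 32)
        return tile_info<instr::f32_32x32x8f16, 32, 32>{}; // new fp16 path in this commit
    else
        return tile_info<instr::f32_32x32x4f16, MPerXdlops, NPerXdlops>{};
}

int main()
{
    constexpr auto info = get_tile_info<32, 32>();
    static_assert(info.which == instr::f32_32x32x8f16, "32x32 tile maps to the 32x32x8 fp16 mfma");
    std::printf("tile %dx%d\n", info.m, info.n);
    return 0;
}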
@@ -759,12 +754,14 @@ struct XdlopsGemm
                 constexpr index_t c_offset =
                     CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();

-                static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) {
+                static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
                     constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
                     constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));

                     mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
-                        p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], p_c_thread);
+                        p_a_wave[Number<a_offset / mfma_type.k_base>{}],
+                        p_b_wave[Number<b_offset / mfma_type.k_base>{}],
+                        p_c_thread);
                 });
             }
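Note on the loop change above: k now advances in steps of mfma_type.k_base over [0, KPack) instead of in unit steps over [0, KPack / k_base), and the a/b offsets are divided by k_base before indexing p_a_wave / p_b_wave, which is consistent with those buffers being addressed in k_base-wide vector elements (for example half4_t) rather than scalars. A small stand-alone check of that index arithmetic follows; it is simplified in that the real offsets come from ADesc/BDesc::CalculateOffset, not from k directly.

#include <cassert>

int main()
{
    constexpr int KPack  = 8;
    constexpr int k_base = 4;

    int issues_old = 0;
    for(int k = 0; k < KPack / k_base; k += 1) // old loop shape
        ++issues_old;

    int issues_new = 0;
    for(int k = 0; k < KPack; k += k_base) // new loop shape
    {
        ++issues_new;
        int element_offset = k;                       // offset counted in scalar (half) elements
        int vector_index   = element_offset / k_base; // index counted in half4-sized elements
        assert(vector_index == k / k_base);
    }

    assert(issues_old == issues_new); // both shapes issue the same number of mfma calls
    return 0;
}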
composable_kernel/include/utility/amd_xdlops.hpp
@@ -394,12 +394,25 @@ struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
     }
 };

-__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x8f16(const half4_t* reg_a,
-                                                            const half4_t* reg_b,
-                                                            c_vec16_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
-
-    return reg_c;
-}
+template <index_t MPerWave, index_t NPerWave, index_t COffset>
+struct intrin_mfma_f32_32x32x8f16;
+
+template <index_t COffset>
+struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
+{
+    template <class FloatC>
+    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
+                reg_a,
+                reg_b,
+                reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
+                0,
+                0,
+                0);
+    }
+};

 __device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x16f16(const half4_t* reg_a,
                                                             const half4_t* reg_b,
                                                             c_vec4_1_t::VecType reg_c)
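The hunk above replaces the free fp16 32x32x8 wrapper, which returned the updated accumulator by value, with a struct template keyed on (MPerWave, NPerWave, COffset) whose static Run updates one float16_t slice of the thread's C register buffer in place. Below is a simplified plain C++ model of that in-place, offset-addressed update; acc_buffer, slice_width and the arithmetic are invented for the sketch, while the real code uses the AsType access shown above together with the MFMA intrinsic.

#include <array>
#include <cstddef>
#include <cstdio>

// one "float16_t" worth of floats and a buffer holding several such slices
constexpr std::size_t slice_width = 16;
using acc_buffer = std::array<float, 2 * slice_width>;

template <std::size_t COffset>
struct mfma_32x32x8_model
{
    static void Run(float a, float b, acc_buffer& c)
    {
        // the real Run hands the COffset-th float16_t slice to
        // llvm_intrin_amdgcn_mfma_f32_32x32x8f16 and writes the result back in place
        for(std::size_t i = 0; i < slice_width; ++i)
            c[COffset * slice_width + i] += a * b;
    }
};

int main()
{
    acc_buffer c{};
    mfma_32x32x8_model<1>::Run(2.0f, 3.0f, c); // updates only the second slice
    std::printf("c[%zu] = %f, c[0] = %f\n", slice_width, c[slice_width], c[0]);
    return 0;
}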
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 16;

-    constexpr index_t GemmMPerWave = 64;
-    constexpr index_t GemmNPerWave = 64;
-    constexpr index_t GemmKPack = 4;
+    constexpr index_t GemmMPerWave = 32;
+    constexpr index_t GemmNPerWave = 32;
+    constexpr index_t GemmKPack = 8;

-    constexpr index_t MRepeat = 1;
-    constexpr index_t NRepeat = 1;
+    constexpr index_t MRepeat = 2;
+    constexpr index_t NRepeat = 2;

     using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
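Note on the retuned driver parameters above: GemmKPack follows the K of the selected instruction (8 for mfma_f32_32x32x8f16 versus 4 for mfma_f32_32x32x4f16), and MRepeat/NRepeat are doubled as the per-xdlops tile shrinks from 64x64 to 32x32. Under the assumption, not stated in the commit, that a wave covers MPerWave * MRepeat by NPerWave * NRepeat output elements, its coverage is unchanged; a trivial stand-alone check of that arithmetic:

constexpr int OldMPerWave = 64, OldNPerWave = 64, OldKPack = 4, OldRepeat = 1;
constexpr int NewMPerWave = 32, NewNPerWave = 32, NewKPack = 8, NewRepeat = 2;

static_assert(NewMPerWave * NewRepeat == OldMPerWave * OldRepeat, "same M coverage per wave");
static_assert(NewNPerWave * NewRepeat == OldNPerWave * OldRepeat, "same N coverage per wave");
static_assert(NewKPack == 2 * OldKPack, "KPack doubled with the instruction's K (8 vs 4)");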