Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e610402f
"...git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "8527c587b6247232a6577f89336dd2dc2bc3377a"
Commit
e610402f
authored
Jun 01, 2021
by
Jing Zhang
Browse files
add fp16 mfma
parent
4ea89209
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
88 additions
and
83 deletions
+88
-83
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+16
-39
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+67
-39
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+5
-5
No files found.
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
e610402f
...
@@ -256,17 +256,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
...
@@ -256,17 +256,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
template
<
index_t
MPerXdlops
,
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
COffset
,
index_t
BStride
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
intrin_mfma_f32_16x16x16f16
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_16x16x16f16
(
p_a
,
p_b
,
reg_c
);
}
}
};
};
...
@@ -289,17 +285,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
...
@@ -289,17 +285,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
template
<
index_t
MPerXdlops
,
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
COffset
,
index_t
BStride
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
intrin_mfma_f32_16x16x4f16
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_16x16x4f16
<
MPerXdlops
,
NPerXdlops
>
(
p_a
,
p_b
,
reg_c
);
}
}
};
};
...
@@ -322,17 +314,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
...
@@ -322,17 +314,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
template
<
index_t
MPerXdlops
,
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
COffset
,
index_t
BStride
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
intrin_mfma_f32_4x4x4f16
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_4x4x4f16
<
MPerXdlops
,
NPerXdlops
>::
run
(
p_a
,
p_b
,
reg_c
);
}
}
};
};
...
@@ -596,43 +584,32 @@ struct XdlopsGemm
...
@@ -596,43 +584,32 @@ struct XdlopsGemm
{
{
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x8f16
,
32
,
32
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x8f16
,
32
,
32
>
{};
}
}
#if 0
template
<
>
template
<
>
static constexpr auto GetXdlopsInfo<half_t, 6
4
, 16>()
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
1
6
,
16
>
()
{
{
return xdlops_info<mfma_instr::mfma_f32_16x16x
4f16, 64, 16, 1, 1, c_vec16_1_t
>{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x
16f16
,
16
,
16
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
16
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
16
,
64
>
()
{
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64
, 1, 1, c_vec16_1_t
>{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4f16
,
16
,
64
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
8
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
8
,
64
>
()
{
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64
, 1, 1, c_vec4_2_t
>{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x4f16
,
8
,
64
>
{};
}
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
4
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
4
,
64
>
()
{
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1, c_vec4_1_t>{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_4x4x4f16
,
4
,
64
>
{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1, c_vec4_1_t>{};
}
}
#if 0
template <>
template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
{
{
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
e610402f
...
@@ -414,60 +414,88 @@ struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
...
@@ -414,60 +414,88 @@ struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
}
}
};
};
__device__
c_vec4_1_t
::
VecType
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
intrin_mfma_f32_16x16x16f16
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
struct
intrin_mfma_f32_16x16x16f16
;
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x16f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
0
);
return
reg_c
;
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x4f16
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
);
template
<
>
template
<
index_t
COffset
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x4f16
<
16
,
64
>
(
const
half4_t
*
reg_a
,
struct
intrin_mfma_f32_16x16x16f16
<
16
,
16
,
COffset
>
const
half4_t
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
2
,
0
,
0
);
template
<
class
FloatC
>
return
reg_c
;
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
}
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x16f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
};
template
<
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_16x16x4f16
<
64
,
16
>
(
const
half4_t
*
reg_a
,
struct
intrin_mfma_f32_16x16x4f16
;
const
half4_t
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
template
<
index_t
COffset
>
struct
intrin_mfma_f32_16x16x4f16
<
16
,
64
,
COffset
>
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
4
);
template
<
class
FloatC
>
return
reg_c
;
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
}
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
2
,
0
,
0
);
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
struct
intrin_mfma_f32_4x4x4f16
;
struct
intrin_mfma_f32_4x4x4f16
;
template
<
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_4x4x4f16
<
4
,
64
>
struct
intrin_mfma_f32_4x4x4f16
<
4
,
64
,
COffset
>
{
{
__device__
static
c_vec4_1_t
::
VecType
template
<
class
FloatC
>
r
un
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
__device__
static
void
R
un
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
4
,
0
,
0
);
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
return
reg_c
;
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
}
}
};
};
template
<
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_4x4x4f16
<
8
,
64
>
struct
intrin_mfma_f32_4x4x4f16
<
8
,
64
,
COffset
>
{
{
__device__
static
c_vec4_2_t
::
VecType
template
<
class
FloatC
>
r
un
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec4_2_t
::
VecType
reg_c
)
__device__
static
void
R
un
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
4
,
0
,
0
);
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
4
,
1
,
0
);
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
,
return
reg_c
;
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_c
(
Number
<
COffset
+
1
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
+
1
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
1
,
0
);
}
}
};
};
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
e610402f
...
@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmMPerWave
=
4
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPack
=
8
;
constexpr
index_t
GemmKPack
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
MRepeat
=
16
;
constexpr
index_t
NRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment