Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3bbd5988
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "991ce41a0d51da85a8f3bd0b1925ac02495ebb94"
Commit
3bbd5988
authored
May 31, 2021
by
Jing Zhang
Browse files
adding fp16 mfma
parent
5c27dcd5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
107 deletions
+67
-107
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
...include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
+4
-4
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+5
-22
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+47
-77
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+2
-2
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+9
-2
No files found.
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
3bbd5988
...
@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple
(
Number
<
NumBlks
>
{},
Number
<
BlkSize
>
{}));
make_tuple
(
Number
<
NumBlks
>
{},
Number
<
BlkSize
>
{}));
StaticBuffer
<
AddressSpace
::
Vgpr
,
StaticBuffer
<
AddressSpace
::
Vgpr
,
vector_type
<
FloatA
B
,
c_blk_nb_bs_desc
.
GetElementSpaceSize
()
>
,
vector_type
<
FloatA
cc
,
c_blk_nb_bs_desc
.
GetElementSpaceSize
()
>
,
c_mr_nr_nx_desc
.
GetElementSpaceSize
()
>
c_mr_nr_nx_desc
.
GetElementSpaceSize
()
>
c_thread_buf
;
c_thread_buf
;
...
@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
M0
>
{},
Number
<
1
>
{},
Number
<
M2
>
{},
Number
<
1
>
{}));
make_tuple
(
Number
<
M0
>
{},
Number
<
1
>
{},
Number
<
M2
>
{},
Number
<
1
>
{}));
StaticBuffer
<
AddressSpace
::
Vgpr
,
Float
AB
,
BlkSize
>
c_blk_buf_
;
StaticBuffer
<
AddressSpace
::
Vgpr
,
Float
C
,
BlkSize
>
c_blk_buf_
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
mr_i
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
mr_i
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
nr_i
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
nr_i
)
{
...
@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple
(
mr_i
,
nr_i
,
xdlops_i
))
>
{}];
make_tuple
(
mr_i
,
nr_i
,
xdlops_i
))
>
{}];
static_for
<
0
,
BlkSize
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
BlkSize
,
1
>
{}([
&
](
auto
j
)
{
c_blk_buf_
(
j
)
=
c_blk
.
template
AsType
<
FloatA
B
>()[
Number
<
c_blk_buf_
(
j
)
=
c_blk
.
template
AsType
<
FloatA
cc
>()[
Number
<
c_blk_nb_bs_desc
.
CalculateOffset
(
make_tuple
(
blk_i
,
j
))
>
{}];
c_blk_nb_bs_desc
.
CalculateOffset
(
make_tuple
(
blk_i
,
j
))
>
{}];
});
});
...
@@ -518,7 +518,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -518,7 +518,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
CGlobalIteratorHacks
{};
CGlobalIteratorHacks
{};
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
Float
Acc
,
Float
C
,
FloatC
,
FloatC
,
decltype
(
c_m0_m1_m2_n_thread_desc
),
decltype
(
c_m0_m1_m2_n_thread_desc
),
decltype
(
c_m0_m1_m2_n_global_desc
),
decltype
(
c_m0_m1_m2_n_global_desc
),
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
3bbd5988
...
@@ -198,18 +198,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
...
@@ -198,18 +198,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
template
<
index_t
MPerXdlops
,
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
COffset
,
index_t
BStride
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
const
auto
p_a
=
reinterpret_cast
<
const
half4_t
*>
(
a
);
intrin_mfma_f32_32x32x4f16
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
const
auto
p_b
=
reinterpret_cast
<
const
half4_t
*>
(
b
);
return
intrin_mfma_f32_32x32x4f16
<
MPerXdlops
,
NPerXdlops
,
AStride
,
BStride
>::
run
(
p_a
,
p_b
,
reg_c
);
}
}
};
};
...
@@ -588,25 +583,13 @@ struct XdlopsGemm
...
@@ -588,25 +583,13 @@ struct XdlopsGemm
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
,
16
,
16
>
{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
,
16
,
16
>
{};
}
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template
<
>
template
<
>
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
64
,
64
>
()
static
constexpr
auto
GetXdlopsInfo
<
half_t
,
64
,
64
>
()
{
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64
, 1, 1, c_vec32_2_t
>{};
return
xdlops_info
<
mfma_instr
::
mfma_f32_32x32x4f16
,
64
,
64
>
{};
}
}
#if 0
template <>
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
{
{
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
3bbd5988
...
@@ -204,8 +204,8 @@ struct intrin_mfma_f32_32x32x1f32;
...
@@ -204,8 +204,8 @@ struct intrin_mfma_f32_32x32x1f32;
template
<
index_t
COffset
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
,
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
f
loat
&
reg_a
,
const
f
loat
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
...
@@ -229,8 +229,8 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
...
@@ -229,8 +229,8 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
template
<
index_t
COffset
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
32
,
64
,
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
32
,
64
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
f
loat
&
reg_a
,
const
f
loat
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
...
@@ -249,8 +249,8 @@ struct intrin_mfma_f32_32x32x2f32;
...
@@ -249,8 +249,8 @@ struct intrin_mfma_f32_32x32x2f32;
template
<
index_t
COffset
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
,
COffset
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
f
loat
&
reg_a
,
const
f
loat
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
...
@@ -269,8 +269,8 @@ struct intrin_mfma_f32_16x16x4f32;
...
@@ -269,8 +269,8 @@ struct intrin_mfma_f32_16x16x4f32;
template
<
index_t
COffset
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_16x16x4f32
<
16
,
16
,
COffset
>
struct
intrin_mfma_f32_16x16x4f32
<
16
,
16
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
f
loat
&
reg_a
,
const
f
loat
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
...
@@ -289,8 +289,8 @@ struct intrin_mfma_f32_16x16x1f32;
...
@@ -289,8 +289,8 @@ struct intrin_mfma_f32_16x16x1f32;
template
<
index_t
COffset
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_16x16x1f32
<
16
,
64
,
COffset
>
struct
intrin_mfma_f32_16x16x1f32
<
16
,
64
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
f
loat
&
reg_a
,
const
f
loat
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
...
@@ -310,8 +310,8 @@ struct intrin_mfma_f32_4x4x1f32;
...
@@ -310,8 +310,8 @@ struct intrin_mfma_f32_4x4x1f32;
template
<
index_t
COffset
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_4x4x1f32
<
4
,
64
,
COffset
>
struct
intrin_mfma_f32_4x4x1f32
<
4
,
64
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
f
loat
&
reg_a
,
const
f
loat
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
...
@@ -327,8 +327,8 @@ struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
...
@@ -327,8 +327,8 @@ struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
template
<
index_t
COffset
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_4x4x1f32
<
8
,
64
,
COffset
>
struct
intrin_mfma_f32_4x4x1f32
<
8
,
64
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
F
loat
A
&
reg_a
,
const
F
loat
B
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
f
loat
&
reg_a
,
const
f
loat
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
...
@@ -349,78 +349,48 @@ struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
...
@@ -349,78 +349,48 @@ struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
AStride
,
index_t
BStride
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
struct
intrin_mfma_f32_32x32x4f16
;
struct
intrin_mfma_f32_32x32x4f16
;
template
<
index_t
AStride
,
index_t
BStride
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_32x32x4f16
<
128
,
64
,
AStride
,
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
64
,
COffset
>
{
__device__
static
c_vec32_4_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_4_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
1
,
1
,
0
);
reg_c
.
s
.
z
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
AStride
],
reg_b
[
0
],
reg_c
.
s
.
z
,
1
,
0
,
0
);
reg_c
.
s
.
w
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
AStride
],
reg_b
[
0
],
reg_c
.
s
.
w
,
1
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
128
,
AStride
,
BStride
>
{
__device__
static
c_vec32_4_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_4_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
1
,
1
,
0
);
reg_c
.
s
.
z
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
BStride
],
reg_c
.
s
.
z
,
1
,
0
,
0
);
reg_c
.
s
.
w
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
BStride
],
reg_c
.
s
.
w
,
1
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
64
,
AStride
,
BStride
>
{
__device__
static
c_vec32_2_t
::
VecType
run
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_2_t
::
VecType
reg_c
)
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
.
s
.
y
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
y
,
1
,
1
,
0
);
return
reg_c
;
}
};
template
<
index_t
AStride
,
index_t
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
32
,
AStride
,
BStride
>
{
{
__device__
static
c_vec32_1_t
::
VecType
template
<
class
FloatC
>
r
un
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_1_t
::
VecType
reg_c
)
__device__
static
void
R
un
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
1
);
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
return
reg_c
;
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_c
(
Number
<
COffset
+
1
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
+
1
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
1
,
0
);
}
}
};
};
template
<
index_t
AStride
,
index_t
BStride
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_32x32x4f16
<
32
,
64
,
AStride
,
BStride
>
struct
intrin_mfma_f32_32x32x4f16
<
32
,
64
,
COffset
>
{
{
__device__
static
c_vec32_1_t
::
VecType
template
<
class
FloatC
>
r
un
(
const
half4_t
*
reg_a
,
const
half4_t
*
reg_b
,
c_vec32_1_t
::
VecType
reg_c
)
__device__
static
void
R
un
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
1
,
0
,
0
);
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
return
reg_c
;
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
}
}
};
};
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
3bbd5988
...
@@ -110,11 +110,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -110,11 +110,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerWave
=
8
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPerWave
=
4
;
constexpr
index_t
GemmKPerWave
=
4
;
constexpr
index_t
MRepeat
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
1
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
...
...
driver/src/conv_driver.cpp
View file @
3bbd5988
...
@@ -78,8 +78,8 @@ int main(int argc, char* argv[])
...
@@ -78,8 +78,8 @@ int main(int argc, char* argv[])
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
#elif 0
constexpr
index_t
N
=
1
;
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
16
;
constexpr
index_t
C
=
16
;
...
@@ -651,6 +651,11 @@ int main(int argc, char* argv[])
...
@@ -651,6 +651,11 @@ int main(int argc, char* argv[])
constexpr
index_t
in_vector_size
=
1
;
constexpr
index_t
in_vector_size
=
1
;
using
acc_data_t
=
float
;
using
acc_data_t
=
float
;
using
out_data_t
=
float
;
using
out_data_t
=
float
;
#elif 1
using
in_data_t
=
half_t
;
constexpr
index_t
in_vector_size
=
1
;
using
acc_data_t
=
float
;
using
out_data_t
=
half_t
;
#elif 0
#elif 0
using
in_data_t
=
float
;
using
in_data_t
=
float
;
constexpr
index_t
in_vector_size
=
1
;
constexpr
index_t
in_vector_size
=
1
;
...
@@ -819,6 +824,7 @@ int main(int argc, char* argv[])
...
@@ -819,6 +824,7 @@ int main(int argc, char* argv[])
check_error
(
out_nkhw_host
,
out_nkhw_device
);
check_error
(
out_nkhw_host
,
out_nkhw_device
);
#if 0
if(do_log)
if(do_log)
{
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
...
@@ -826,5 +832,6 @@ int main(int argc, char* argv[])
...
@@ -826,5 +832,6 @@ int main(int argc, char* argv[])
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
}
}
#endif
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment