Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a13bf453
Commit
a13bf453
authored
Feb 26, 2022
by
rocking
Browse files
rename ushort to bhalf_t
parent
010ef9dc
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
89 additions
and
87 deletions
+89
-87
composable_kernel/include/tensor_operation/element_wise_operation.hpp
...ernel/include/tensor_operation/element_wise_operation.hpp
+1
-1
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+4
-4
composable_kernel/include/utility/amd_buffer_addressing.hpp
composable_kernel/include/utility/amd_buffer_addressing.hpp
+8
-8
composable_kernel/include/utility/data_type.hpp
composable_kernel/include/utility/data_type.hpp
+4
-4
device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
...rc/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
+39
-39
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+8
-8
host/host_tensor/CMakeLists.txt
host/host_tensor/CMakeLists.txt
+4
-2
host/host_tensor/include/host_tensor.hpp
host/host_tensor/include/host_tensor.hpp
+4
-2
host/host_tensor/include/host_tensor_generator.hpp
host/host_tensor/include/host_tensor_generator.hpp
+9
-10
host/host_tensor/src/host_tensor.cpp
host/host_tensor/src/host_tensor.cpp
+1
-2
profiler/include/profile_conv_fwd_impl.hpp
profiler/include/profile_conv_fwd_impl.hpp
+3
-3
test/conv2d_fwd/main.cpp
test/conv2d_fwd/main.cpp
+4
-4
No files found.
composable_kernel/include/tensor_operation/element_wise_operation.hpp
View file @
a13bf453
...
@@ -13,7 +13,7 @@ struct PassThrough
...
@@ -13,7 +13,7 @@ struct PassThrough
__host__
__device__
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
ushor
t
&
y
,
const
ushor
t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
bhalf_
t
&
y
,
const
bhalf_
t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
{
y
=
x
;
}
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
a13bf453
...
@@ -474,7 +474,7 @@ struct MfmaSelector
...
@@ -474,7 +474,7 @@ struct MfmaSelector
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
ushor
t
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
bhalf_
t
,
32
,
32
>
()
{
{
#if defined(CK_AMD_GPU_GFX90A)
#if defined(CK_AMD_GPU_GFX90A)
return
MfmaInstr
::
mfma_f32_32x32x8bf16_1k
;
return
MfmaInstr
::
mfma_f32_32x32x8bf16_1k
;
...
@@ -484,7 +484,7 @@ struct MfmaSelector
...
@@ -484,7 +484,7 @@ struct MfmaSelector
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
ushor
t
,
16
,
16
>
()
static
constexpr
auto
GetMfma
<
bhalf_
t
,
16
,
16
>
()
{
{
#if defined(CK_AMD_GPU_GFX90A)
#if defined(CK_AMD_GPU_GFX90A)
return
MfmaInstr
::
mfma_f32_16x16x16bf16_1k
;
return
MfmaInstr
::
mfma_f32_16x16x16bf16_1k
;
...
@@ -662,8 +662,8 @@ struct XdlopsGemm
...
@@ -662,8 +662,8 @@ struct XdlopsGemm
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
{
static_assert
(
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
static_assert
(
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
ushor
t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
,
is_same
<
base_type
,
bhalf_
t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
,
"base base_type must be float, half,
ushort
, and int8_t!"
);
"base base_type must be float, half,
bfloat16
, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
...
...
composable_kernel/include/utility/amd_buffer_addressing.hpp
View file @
a13bf453
...
@@ -51,7 +51,7 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
...
@@ -51,7 +51,7 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4i8"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4i8"
);
// buffer load i16
// buffer load i16
__device__
ushor
t
__device__
bhalf_
t
llvm_amdgcn_raw_buffer_load_i16
(
int32x4_t
srsrc
,
llvm_amdgcn_raw_buffer_load_i16
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
voffset
,
index_t
soffset
,
index_t
soffset
,
...
@@ -149,7 +149,7 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
...
@@ -149,7 +149,7 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
// buffer store i16
// buffer store i16
__device__
void
__device__
void
llvm_amdgcn_raw_buffer_store_i16
(
ushor
t
vdata
,
llvm_amdgcn_raw_buffer_store_i16
(
bhalf_
t
vdata
,
int32x4_t
rsrc
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
voffset
,
index_t
soffset
,
index_t
soffset
,
...
@@ -266,7 +266,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
...
@@ -266,7 +266,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
ushor
t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_
t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
...
@@ -365,7 +365,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
...
@@ -365,7 +365,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
return
bit_cast
<
half8_t
>
(
tmp
);
return
bit_cast
<
half8_t
>
(
tmp
);
}
}
}
}
else
if
constexpr
(
is_same
<
T
,
ushor
t
>::
value
)
else
if
constexpr
(
is_same
<
T
,
bhalf_
t
>::
value
)
{
{
if
constexpr
(
N
==
1
)
if
constexpr
(
N
==
1
)
{
{
...
@@ -522,7 +522,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
...
@@ -522,7 +522,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
ushor
t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_
t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
...
@@ -625,7 +625,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
...
@@ -625,7 +625,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
#endif
#endif
}
}
}
}
else
if
constexpr
(
is_same
<
T
,
ushor
t
>::
value
)
else
if
constexpr
(
is_same
<
T
,
bhalf_
t
>::
value
)
{
{
if
constexpr
(
N
==
1
)
if
constexpr
(
N
==
1
)
{
{
...
@@ -653,7 +653,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
...
@@ -653,7 +653,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
8
)
{
{
3
vector_type
<
ushor
t
,
8
>
tmp
{
src_thread_data
};
vector_type
<
bhalf_
t
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
0
>
{}],
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
...
@@ -664,7 +664,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
...
@@ -664,7 +664,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
1
>
{}],
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
ushor
t
),
dst_wave_addr_offset
+
4
*
sizeof
(
bhalf_
t
),
0
);
0
);
}
}
}
}
...
...
composable_kernel/include/utility/data_type.hpp
View file @
a13bf453
...
@@ -108,9 +108,9 @@ struct scalar_type<half_t>
...
@@ -108,9 +108,9 @@ struct scalar_type<half_t>
};
};
template
<
>
template
<
>
struct
scalar_type
<
ushor
t
>
struct
scalar_type
<
bhalf_
t
>
{
{
using
type
=
ushor
t
;
using
type
=
bhalf_
t
;
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
...
@@ -937,7 +937,7 @@ __host__ __device__ Y type_convert(X x)
...
@@ -937,7 +937,7 @@ __host__ __device__ Y type_convert(X x)
// convert bfp16 to fp32
// convert bfp16 to fp32
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
(
ushor
t
x
)
inline
__host__
__device__
float
type_convert
(
bhalf_
t
x
)
{
{
union
union
{
{
...
@@ -950,7 +950,7 @@ inline __host__ __device__ float type_convert(ushort x)
...
@@ -950,7 +950,7 @@ inline __host__ __device__ float type_convert(ushort x)
// convert fp32 to bfp16
// convert fp32 to bfp16
template
<
>
template
<
>
inline
__host__
__device__
ushor
t
type_convert
(
float
x
)
inline
__host__
__device__
bhalf_
t
type_convert
(
float
x
)
{
{
union
union
{
{
...
...
device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
View file @
a13bf453
...
@@ -32,19 +32,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple<
...
@@ -32,19 +32,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple<
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwdDefault
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
// clang-format on
// clang-format on
>
;
>
;
...
@@ -54,19 +54,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple<
...
@@ -54,19 +54,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple<
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1P0
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
// clang-format on
// clang-format on
>
;
>
;
...
@@ -76,19 +76,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple
...
@@ -76,19 +76,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
ushort
,
ushort
,
ushor
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
bhalf_t
,
bhalf_t
,
bhalf_
t
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvFwd1x1S1P0
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
// clang-format on
// clang-format on
>
;
>
;
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
a13bf453
...
@@ -77,7 +77,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
...
@@ -77,7 +77,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in
.
mDesc
.
GetLengths
()[
3
])
wi
<
in
.
mDesc
.
GetLengths
()[
3
])
{
{
if
constexpr
(
is_same
<
TIn
,
ushor
t
>::
value
)
if
constexpr
(
is_same
<
TIn
,
bhalf_
t
>::
value
)
{
{
v
+=
ck
::
type_convert
<
float
>
(
in
(
n
,
c
,
hi
,
wi
))
*
v
+=
ck
::
type_convert
<
float
>
(
in
(
n
,
c
,
hi
,
wi
))
*
ck
::
type_convert
<
float
>
(
wei
(
k
,
c
,
y
,
x
));
ck
::
type_convert
<
float
>
(
wei
(
k
,
c
,
y
,
x
));
...
@@ -92,9 +92,9 @@ void host_convolution_forward(const Tensor<TIn>& in,
...
@@ -92,9 +92,9 @@ void host_convolution_forward(const Tensor<TIn>& in,
}
}
}
}
if
constexpr
(
is_same
<
TOut
,
ushor
t
>::
value
)
if
constexpr
(
is_same
<
TOut
,
bhalf_
t
>::
value
)
{
{
out
(
n
,
k
,
ho
,
wo
)
=
ck
::
type_convert
<
ushor
t
>
(
static_cast
<
float
>
(
v
));
out
(
n
,
k
,
ho
,
wo
)
=
ck
::
type_convert
<
bhalf_
t
>
(
static_cast
<
float
>
(
v
));
}
}
else
else
{
{
...
@@ -115,7 +115,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
...
@@ -115,7 +115,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
1
]
&&
wi
>=
0
&&
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
1
]
&&
wi
>=
0
&&
wi
<
in
.
mDesc
.
GetLengths
()[
2
])
wi
<
in
.
mDesc
.
GetLengths
()[
2
])
{
{
if
constexpr
(
is_same
<
TIn
,
ushor
t
>::
value
)
if
constexpr
(
is_same
<
TIn
,
bhalf_
t
>::
value
)
{
{
v
+=
ck
::
type_convert
<
float
>
(
in
(
n
,
hi
,
wi
,
c
))
*
v
+=
ck
::
type_convert
<
float
>
(
in
(
n
,
hi
,
wi
,
c
))
*
ck
::
type_convert
<
float
>
(
wei
(
k
,
y
,
x
,
c
));
ck
::
type_convert
<
float
>
(
wei
(
k
,
y
,
x
,
c
));
...
@@ -129,9 +129,9 @@ void host_convolution_forward(const Tensor<TIn>& in,
...
@@ -129,9 +129,9 @@ void host_convolution_forward(const Tensor<TIn>& in,
}
}
}
}
}
}
if
constexpr
(
is_same
<
TOut
,
ushor
t
>::
value
)
if
constexpr
(
is_same
<
TOut
,
bhalf_
t
>::
value
)
{
{
out
(
n
,
ho
,
wo
,
k
)
=
ck
::
type_convert
<
ushor
t
>
(
static_cast
<
float
>
(
v
));
out
(
n
,
ho
,
wo
,
k
)
=
ck
::
type_convert
<
bhalf_
t
>
(
static_cast
<
float
>
(
v
));
}
}
else
else
{
{
...
@@ -259,9 +259,9 @@ int main(int argc, char* argv[])
...
@@ -259,9 +259,9 @@ int main(int argc, char* argv[])
using
acc_data_t
=
float
;
using
acc_data_t
=
float
;
using
out_data_t
=
half_t
;
using
out_data_t
=
half_t
;
#elif 0
#elif 0
using
in_data_t
=
ushor
t
;
using
in_data_t
=
bhalf_
t
;
using
acc_data_t
=
float
;
using
acc_data_t
=
float
;
using
out_data_t
=
ushor
t
;
using
out_data_t
=
bhalf_
t
;
#elif 1
#elif 1
using
in_data_t
=
int8_t
;
using
in_data_t
=
int8_t
;
using
acc_data_t
=
int32_t
;
using
acc_data_t
=
int32_t
;
...
...
host/host_tensor/CMakeLists.txt
View file @
a13bf453
include_directories
(
BEFORE
include_directories
(
BEFORE
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/utility
include
include
)
)
...
...
host/host_tensor/include/host_tensor.hpp
View file @
a13bf453
...
@@ -8,6 +8,8 @@
...
@@ -8,6 +8,8 @@
#include <utility>
#include <utility>
#include <cassert>
#include <cassert>
#include <iostream>
#include <iostream>
#include "data_type.hpp"
template
<
typename
Range
>
template
<
typename
Range
>
std
::
ostream
&
LogRange
(
std
::
ostream
&
os
,
Range
&&
range
,
std
::
string
delim
)
std
::
ostream
&
LogRange
(
std
::
ostream
&
os
,
Range
&&
range
,
std
::
string
delim
)
...
@@ -311,7 +313,7 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> s
...
@@ -311,7 +313,7 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> s
void
ostream_HostTensorDescriptor
(
const
HostTensorDescriptor
&
desc
,
std
::
ostream
&
os
=
std
::
cout
);
void
ostream_HostTensorDescriptor
(
const
HostTensorDescriptor
&
desc
,
std
::
ostream
&
os
=
std
::
cout
);
float
bf16_to_f32_
(
ushor
t
src_val
);
float
bf16_to_f32_
(
ck
::
bhalf_
t
src_val
);
template
<
typename
T
>
template
<
typename
T
>
void
check_error
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
void
check_error
(
const
Tensor
<
T
>&
ref
,
const
Tensor
<
T
>&
result
)
...
@@ -320,7 +322,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -320,7 +322,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
float
max_diff
=
-
1
;
float
max_diff
=
-
1
;
float
ref_value
=
0
,
result_value
=
0
;
float
ref_value
=
0
,
result_value
=
0
;
if
constexpr
(
std
::
is_same
<
ushor
t
,
T
>::
value
)
if
constexpr
(
std
::
is_same
<
ck
::
bhalf_
t
,
T
>::
value
)
{
{
for
(
int
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
for
(
int
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
{
{
...
...
host/host_tensor/include/host_tensor_generator.hpp
View file @
a13bf453
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include <cmath>
#include <cmath>
#include "config.hpp"
#include "config.hpp"
#include "data_type.hpp"
template
<
typename
T
>
template
<
typename
T
>
struct
GeneratorTensor_0
struct
GeneratorTensor_0
...
@@ -28,14 +27,14 @@ struct GeneratorTensor_1
...
@@ -28,14 +27,14 @@ struct GeneratorTensor_1
};
};
template
<
>
template
<
>
struct
GeneratorTensor_1
<
ushor
t
>
struct
GeneratorTensor_1
<
ck
::
bhalf_
t
>
{
{
float
value
=
1.0
;
float
value
=
1.0
;
template
<
typename
...
Is
>
template
<
typename
...
Is
>
ushor
t
operator
()(
Is
...)
ck
::
bhalf_
t
operator
()(
Is
...)
{
{
return
ck
::
type_convert
<
ushor
t
>
(
value
);
return
ck
::
type_convert
<
ck
::
bhalf_
t
>
(
value
);
}
}
};
};
...
@@ -65,16 +64,16 @@ struct GeneratorTensor_2
...
@@ -65,16 +64,16 @@ struct GeneratorTensor_2
};
};
template
<
>
template
<
>
struct
GeneratorTensor_2
<
ushor
t
>
struct
GeneratorTensor_2
<
ck
::
bhalf_
t
>
{
{
int
min_value
=
0
;
int
min_value
=
0
;
int
max_value
=
1
;
int
max_value
=
1
;
template
<
typename
...
Is
>
template
<
typename
...
Is
>
ushor
t
operator
()(
Is
...)
ck
::
bhalf_
t
operator
()(
Is
...)
{
{
float
tmp
=
(
std
::
rand
()
%
(
max_value
-
min_value
))
+
min_value
;
float
tmp
=
(
std
::
rand
()
%
(
max_value
-
min_value
))
+
min_value
;
return
ck
::
type_convert
<
ushor
t
>
(
tmp
);
return
ck
::
type_convert
<
ck
::
bhalf_
t
>
(
tmp
);
}
}
};
};
...
@@ -107,19 +106,19 @@ struct GeneratorTensor_3
...
@@ -107,19 +106,19 @@ struct GeneratorTensor_3
};
};
template
<
>
template
<
>
struct
GeneratorTensor_3
<
ushor
t
>
struct
GeneratorTensor_3
<
ck
::
bhalf_
t
>
{
{
float
min_value
=
0
;
float
min_value
=
0
;
float
max_value
=
1
;
float
max_value
=
1
;
template
<
typename
...
Is
>
template
<
typename
...
Is
>
ushor
t
operator
()(
Is
...)
ck
::
bhalf_
t
operator
()(
Is
...)
{
{
float
tmp
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
float
tmp
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
float
fp32_tmp
=
min_value
+
tmp
*
(
max_value
-
min_value
);
float
fp32_tmp
=
min_value
+
tmp
*
(
max_value
-
min_value
);
return
ck
::
type_convert
<
ushor
t
>
(
fp32_tmp
);
return
ck
::
type_convert
<
ck
::
bhalf_
t
>
(
fp32_tmp
);
}
}
};
};
...
...
host/host_tensor/src/host_tensor.cpp
View file @
a13bf453
#include <cassert>
#include <cassert>
#include "host_tensor.hpp"
#include "host_tensor.hpp"
void
HostTensorDescriptor
::
CalculateStrides
()
void
HostTensorDescriptor
::
CalculateStrides
()
...
@@ -65,7 +64,7 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream
...
@@ -65,7 +64,7 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream
os
<<
"}"
<<
std
::
endl
;
os
<<
"}"
<<
std
::
endl
;
}
}
float
bf16_to_f32_
(
ushor
t
src_val
)
float
bf16_to_f32_
(
ck
::
bhalf_
t
src_val
)
{
{
union
union
{
{
...
...
profiler/include/profile_conv_fwd_impl.hpp
View file @
a13bf453
...
@@ -174,9 +174,9 @@ void profile_conv_fwd_impl(int do_verification,
...
@@ -174,9 +174,9 @@ void profile_conv_fwd_impl(int do_verification,
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
}
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ushor
t
>
&&
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
bhalf_
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ushor
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
bhalf_
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ushor
t
>
)
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
bhalf_
t
>
)
{
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
...
...
test/conv2d_fwd/main.cpp
View file @
a13bf453
...
@@ -198,9 +198,9 @@ int main(int argc, char* argv[])
...
@@ -198,9 +198,9 @@ int main(int argc, char* argv[])
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
}
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ushor
t
>
&&
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
bhalf_
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ushor
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
bhalf_
t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ushor
t
>
)
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
bhalf_
t
>
)
{
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
...
@@ -292,7 +292,7 @@ int main(int argc, char* argv[])
...
@@ -292,7 +292,7 @@ int main(int argc, char* argv[])
}
}
else
if
(
data_type
==
2
)
else
if
(
data_type
==
2
)
{
{
Run
(
ushort
(),
ushort
(),
ushor
t
());
Run
(
bhalf_t
(),
bhalf_t
(),
bhalf_
t
());
}
}
else
if
(
data_type
==
3
)
else
if
(
data_type
==
3
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment