Commit 615e1d3e authored by Chao Liu's avatar Chao Liu
Browse files

update conv 1d and 3d instance

parent 11edd0f0
......@@ -17,6 +17,35 @@ namespace tensor_operation {
namespace device {
namespace instance {
// conv1d_fwd
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<1, NWC, KXC, NWK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<1,
NWC,
KXC,
NWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d_fwd
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -90,6 +119,55 @@ void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough,
PassThrough>>>& instances);
// conv3d_fwd
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvFwd<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<3,
NDHWC,
KZYXC,
NDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
......@@ -124,8 +202,33 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
......@@ -153,6 +256,31 @@ struct DeviceOperationInstanceFactory<
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
......
......@@ -37,9 +37,9 @@ target_link_libraries(ckProfiler PRIVATE utility)
#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
#target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
#target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment