"vscode:/vscode.git/clone" did not exist on "22443f7aaec112080e9a884bb047bc01db1c9ffd"
Unverified Commit b73ae242 authored by Rostyslav Geyyer, committed by GitHub
Browse files

Add int4 example for convnd_fwd_bias_relu_add (#375)

* Add int4 example for convnd_fwd_bias_relu_add

* Fix AddReluAdd for building without int4 support

* Update CMakeLists.txt

* Format

* Convert int4 tensors for int8 kernel

* Fix device memory allocation

* Format

* Format
parent d520d0cf
add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_fp16 grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp) add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_fp16 grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp)
target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_fp16 PRIVATE utility)
add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_fp32 grouped_convnd_fwd_bias_relu_add_xdl_fp32.cpp) add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_fp32 grouped_convnd_fwd_bias_relu_add_xdl_fp32.cpp)
target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_fp32 PRIVATE utility)
add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_bf16 grouped_convnd_fwd_bias_relu_add_xdl_bf16.cpp) add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_bf16 grouped_convnd_fwd_bias_relu_add_xdl_bf16.cpp)
target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_bf16 PRIVATE utility)
add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_int8 grouped_convnd_fwd_bias_relu_add_xdl_int8.cpp) add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_int8 grouped_convnd_fwd_bias_relu_add_xdl_int8.cpp)
target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_int8 PRIVATE utility)
\ No newline at end of file if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_int4 grouped_convnd_fwd_bias_relu_add_xdl_int4.cpp)
endif() # USE_BITINT_EXTENSION_INT4
...@@ -26,13 +26,16 @@ void print_helper_msg() ...@@ -26,13 +26,16 @@ void print_helper_msg()
} }
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InKernelDataType,
typename WeiDataType, typename WeiKernelDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename OutDataType, typename OutKernelDataType,
typename InElementOp, typename InElementOp,
typename WeiElementOp, typename WeiElementOp,
typename OutElementOp, typename OutElementOp,
typename InUserDataType,
typename WeiUserDataType,
typename OutUserDataType,
typename DeviceConvNDFwdInstance> typename DeviceConvNDFwdInstance>
int run_grouped_conv_fwd_bias_relu_add(bool do_verification, int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
int init_method, int init_method,
...@@ -47,12 +50,12 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -47,12 +50,12 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
Tensor<InDataType> in(in_g_n_c_wis_desc); Tensor<InUserDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc); Tensor<WeiUserDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> bias(bias_g_n_k_wos_desc); Tensor<OutUserDataType> bias(bias_g_n_k_wos_desc);
Tensor<OutDataType> residual(residual_g_n_k_wos_desc); Tensor<OutUserDataType> residual(residual_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc); Tensor<OutUserDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc); Tensor<OutKernelDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl;
...@@ -64,26 +67,38 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -64,26 +67,38 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in.GenerateTensorValue(GeneratorTensor_2<InUserDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); wei.GenerateTensorValue(GeneratorTensor_2<WeiUserDataType>{-5, 5});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); bias.GenerateTensorValue(GeneratorTensor_2<OutUserDataType>{-5, 5});
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in.GenerateTensorValue(GeneratorTensor_3<InUserDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei.GenerateTensorValue(GeneratorTensor_3<WeiUserDataType>{-0.5, 0.5});
bias.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5}); bias.GenerateTensorValue(GeneratorTensor_3<OutUserDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize()); DeviceMem bias_device_buf(sizeof(OutKernelDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem residual_device_buf(sizeof(OutDataType) * residual.mDesc.GetElementSpaceSize()); DeviceMem residual_device_buf(sizeof(OutKernelDataType) * residual.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
const Tensor<InKernelDataType> in_converted(in);
const Tensor<WeiKernelDataType> wei_converted(wei);
const Tensor<OutKernelDataType> bias_converted(bias);
const Tensor<OutKernelDataType> residual_converted(residual);
in_device_buf.ToDevice(in_converted.mData.data());
wei_device_buf.ToDevice(wei_converted.mData.data());
bias_device_buf.ToDevice(bias_converted.mData.data());
residual_device_buf.ToDevice(residual_converted.mData.data());
#else // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
in_device_buf.ToDevice(in.mData.data()); in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data()); bias_device_buf.ToDevice(bias.mData.data());
residual_device_buf.ToDevice(residual.mData.data()); residual_device_buf.ToDevice(residual.mData.data());
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{}; std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
...@@ -154,7 +169,7 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -154,7 +169,7 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops(); std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>(); std::size_t num_btype = conv_param.GetByte<InUserDataType, WeiUserDataType, OutUserDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time; float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time; float gb_per_sec = num_btype / 1.E6 / avg_time;
...@@ -168,8 +183,8 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -168,8 +183,8 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
Tensor<CShuffleDataType> c_host(out_g_n_k_wos_desc); Tensor<CShuffleDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType, InUserDataType,
WeiDataType, WeiUserDataType,
CShuffleDataType, CShuffleDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
...@@ -196,10 +211,22 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -196,10 +211,22 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
const Tensor<OutUserDataType> out_device_converted(out_device);
return ck::utils::check_err(out_device_converted.mData,
out_host.mData,
"Error: incorrect results!",
1e-5f,
1e-4f)
? 0
: 1;
#else // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
return ck::utils::check_err( return ck::utils::check_err(
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f) out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
? 0 ? 0
: 1; : 1;
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
} }
return 0; return 0;
......
...@@ -7,13 +7,19 @@ ...@@ -7,13 +7,19 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::bhalf_t; // kernel data types
using WeiDataType = ck::bhalf_t; using InKernelDataType = ck::bhalf_t;
using WeiKernelDataType = ck::bhalf_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = float;
using BiasDataType = ck::bhalf_t; using BiasKernelDataType = ck::bhalf_t;
using ResidualDataType = ck::bhalf_t; using ResidualKernelDataType = ck::bhalf_t;
using OutDataType = ck::bhalf_t; using OutKernelDataType = ck::bhalf_t;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance =
WeiLayout, WeiLayout,
ck::Tuple<BiasLayout, ResidualLayout>, ck::Tuple<BiasLayout, ResidualLayout>,
OutLayout, OutLayout,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasDataType, ResidualDataType>, ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -181,13 +187,16 @@ int main(int argc, char* argv[]) ...@@ -181,13 +187,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<1, return run_grouped_conv_fwd_bias_relu_add<1,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<1, DeviceGroupedConvNDFwdInstance<1,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -290,13 +299,16 @@ int main(int argc, char* argv[]) ...@@ -290,13 +299,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<2, return run_grouped_conv_fwd_bias_relu_add<2,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<2, DeviceGroupedConvNDFwdInstance<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -413,13 +425,16 @@ int main(int argc, char* argv[]) ...@@ -413,13 +425,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<3, return run_grouped_conv_fwd_bias_relu_add<3,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<3, DeviceGroupedConvNDFwdInstance<3,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -7,13 +7,19 @@ ...@@ -7,13 +7,19 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::half_t; // kernel data types
using WeiDataType = ck::half_t; using InKernelDataType = ck::half_t;
using WeiKernelDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = ck::half_t; using CShuffleDataType = ck::half_t;
using BiasDataType = ck::half_t; using BiasKernelDataType = ck::half_t;
using ResidualDataType = ck::half_t; using ResidualKernelDataType = ck::half_t;
using OutDataType = ck::half_t; using OutKernelDataType = ck::half_t;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance =
WeiLayout, WeiLayout,
ck::Tuple<BiasLayout, ResidualLayout>, ck::Tuple<BiasLayout, ResidualLayout>,
OutLayout, OutLayout,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasDataType, ResidualDataType>, ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -181,13 +187,16 @@ int main(int argc, char* argv[]) ...@@ -181,13 +187,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<1, return run_grouped_conv_fwd_bias_relu_add<1,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<1, DeviceGroupedConvNDFwdInstance<1,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -290,13 +299,16 @@ int main(int argc, char* argv[]) ...@@ -290,13 +299,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<2, return run_grouped_conv_fwd_bias_relu_add<2,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<2, DeviceGroupedConvNDFwdInstance<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -413,13 +425,16 @@ int main(int argc, char* argv[]) ...@@ -413,13 +425,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<3, return run_grouped_conv_fwd_bias_relu_add<3,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<3, DeviceGroupedConvNDFwdInstance<3,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -7,13 +7,19 @@ ...@@ -7,13 +7,19 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = float; // kernel data types
using WeiDataType = float; using InKernelDataType = float;
using WeiKernelDataType = float;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = float;
using BiasDataType = float; using BiasKernelDataType = float;
using ResidualDataType = float; using ResidualKernelDataType = float;
using OutDataType = float; using OutKernelDataType = float;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance =
WeiLayout, WeiLayout,
ck::Tuple<BiasLayout, ResidualLayout>, ck::Tuple<BiasLayout, ResidualLayout>,
OutLayout, OutLayout,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasDataType, ResidualDataType>, ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -181,13 +187,16 @@ int main(int argc, char* argv[]) ...@@ -181,13 +187,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<1, return run_grouped_conv_fwd_bias_relu_add<1,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<1, DeviceGroupedConvNDFwdInstance<1,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -290,13 +299,16 @@ int main(int argc, char* argv[]) ...@@ -290,13 +299,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<2, return run_grouped_conv_fwd_bias_relu_add<2,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<2, DeviceGroupedConvNDFwdInstance<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -413,13 +425,16 @@ int main(int argc, char* argv[]) ...@@ -413,13 +425,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<3, return run_grouped_conv_fwd_bias_relu_add<3,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<3, DeviceGroupedConvNDFwdInstance<3,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -7,13 +7,19 @@ ...@@ -7,13 +7,19 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = int8_t; // kernel data types
using WeiDataType = int8_t; using InKernelDataType = int8_t;
using WeiKernelDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CShuffleDataType = int8_t; using CShuffleDataType = int8_t;
using BiasDataType = int8_t; using BiasKernelDataType = int8_t;
using ResidualDataType = int8_t; using ResidualKernelDataType = int8_t;
using OutDataType = int8_t; using OutKernelDataType = int8_t;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance =
WeiLayout, WeiLayout,
ck::Tuple<BiasLayout, ResidualLayout>, ck::Tuple<BiasLayout, ResidualLayout>,
OutLayout, OutLayout,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasDataType, ResidualDataType>, ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -181,13 +187,16 @@ int main(int argc, char* argv[]) ...@@ -181,13 +187,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<1, return run_grouped_conv_fwd_bias_relu_add<1,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<1, DeviceGroupedConvNDFwdInstance<1,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -290,13 +299,16 @@ int main(int argc, char* argv[]) ...@@ -290,13 +299,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<2, return run_grouped_conv_fwd_bias_relu_add<2,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<2, DeviceGroupedConvNDFwdInstance<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -413,13 +425,16 @@ int main(int argc, char* argv[]) ...@@ -413,13 +425,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<3, return run_grouped_conv_fwd_bias_relu_add<3,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<3, DeviceGroupedConvNDFwdInstance<3,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -98,6 +98,18 @@ struct AddReluAdd ...@@ -98,6 +98,18 @@ struct AddReluAdd
int32_t c = b + x2; int32_t c = b + x2;
y = c; y = c;
} }
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// int4 specialization of the add -> relu -> add functor; it is compiled only
// when CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 is defined.
//
// Computes y = max(x0 + x1, 0) + x2, holding every intermediate in int32_t.
//
// y  - output, written as int4_t
// x0 - first addend, int8_t (note: wider than the other operands)
// x1 - second addend, int4_t, added to x0 before the clamp
// x2 - third addend, int4_t, added after the clamp
template <>
__host__ __device__ constexpr void operator()<int4_t, int8_t, int4_t, int4_t>(
int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const
{
// First add, stored in a 32-bit intermediate.
int32_t a = x0 + x1;
// ReLU: negative sums become zero.
int32_t b = a > 0 ? a : 0;
// Second add, still in 32 bits.
int32_t c = b + x2;
// Implicit int32_t -> int4_t conversion; there is no clamp or saturation here.
// NOTE(review): presumably callers keep the result within the int4 range - confirm.
y = c;
}
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
}; };
struct AddHardswishAdd struct AddHardswishAdd
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment