Commit 451aef90 authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Add int4 example for convnd_fwd_bias_relu_add

parent 9efd033b
...@@ -9,3 +9,8 @@ target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_bf16 PRIVATE ...@@ -9,3 +9,8 @@ target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_bf16 PRIVATE
add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_int8 grouped_convnd_fwd_bias_relu_add_xdl_int8.cpp) add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_int8 grouped_convnd_fwd_bias_relu_add_xdl_int8.cpp)
target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_int8 PRIVATE utility) target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_int8 PRIVATE utility)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_convnd_fwd_bias_relu_add_xdl_int4 grouped_convnd_fwd_bias_relu_add_xdl_int4.cpp)
target_link_libraries(example_grouped_convnd_fwd_bias_relu_add_xdl_int4 PRIVATE utility)
endif() # USE_BITINT_EXTENSION_INT4
\ No newline at end of file
...@@ -26,13 +26,16 @@ void print_helper_msg() ...@@ -26,13 +26,16 @@ void print_helper_msg()
} }
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InKernelDataType,
typename WeiDataType, typename WeiKernelDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename OutDataType, typename OutKernelDataType,
typename InElementOp, typename InElementOp,
typename WeiElementOp, typename WeiElementOp,
typename OutElementOp, typename OutElementOp,
typename InUserDataType,
typename WeiUserDataType,
typename OutUserDataType,
typename DeviceConvNDFwdInstance> typename DeviceConvNDFwdInstance>
int run_grouped_conv_fwd_bias_relu_add(bool do_verification, int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
int init_method, int init_method,
...@@ -47,12 +50,12 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -47,12 +50,12 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
Tensor<InDataType> in(in_g_n_c_wis_desc); Tensor<InUserDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc); Tensor<WeiUserDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> bias(bias_g_n_k_wos_desc); Tensor<OutUserDataType> bias(bias_g_n_k_wos_desc);
Tensor<OutDataType> residual(residual_g_n_k_wos_desc); Tensor<OutUserDataType> residual(residual_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc); Tensor<OutUserDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc); Tensor<OutUserDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl;
...@@ -64,21 +67,21 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -64,21 +67,21 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in.GenerateTensorValue(GeneratorTensor_2<InUserDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); wei.GenerateTensorValue(GeneratorTensor_2<WeiUserDataType>{-5, 5});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); bias.GenerateTensorValue(GeneratorTensor_2<OutUserDataType>{-5, 5});
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in.GenerateTensorValue(GeneratorTensor_3<InUserDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei.GenerateTensorValue(GeneratorTensor_3<WeiUserDataType>{-0.5, 0.5});
bias.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5}); bias.GenerateTensorValue(GeneratorTensor_3<OutUserDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); DeviceMem in_device_buf(sizeof(InUserDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); DeviceMem wei_device_buf(sizeof(WeiUserDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize()); DeviceMem bias_device_buf(sizeof(OutUserDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem residual_device_buf(sizeof(OutDataType) * residual.mDesc.GetElementSpaceSize()); DeviceMem residual_device_buf(sizeof(OutUserDataType) * residual.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutUserDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data()); in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
...@@ -154,7 +157,7 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -154,7 +157,7 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops(); std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>(); std::size_t num_btype = conv_param.GetByte<InUserDataType, WeiUserDataType, OutUserDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time; float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time; float gb_per_sec = num_btype / 1.E6 / avg_time;
...@@ -168,8 +171,8 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -168,8 +171,8 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
Tensor<CShuffleDataType> c_host(out_g_n_k_wos_desc); Tensor<CShuffleDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType, InUserDataType,
WeiDataType, WeiUserDataType,
CShuffleDataType, CShuffleDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
......
...@@ -7,13 +7,19 @@ ...@@ -7,13 +7,19 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::bhalf_t; // kernel data types
using WeiDataType = ck::bhalf_t; using InKernelDataType = ck::bhalf_t;
using WeiKernelDataType = ck::bhalf_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = float;
using BiasDataType = ck::bhalf_t; using BiasKernelDataType = ck::bhalf_t;
using ResidualDataType = ck::bhalf_t; using ResidualKernelDataType = ck::bhalf_t;
using OutDataType = ck::bhalf_t; using OutKernelDataType = ck::bhalf_t;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance =
WeiLayout, WeiLayout,
ck::Tuple<BiasLayout, ResidualLayout>, ck::Tuple<BiasLayout, ResidualLayout>,
OutLayout, OutLayout,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasDataType, ResidualDataType>, ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -181,13 +187,16 @@ int main(int argc, char* argv[]) ...@@ -181,13 +187,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<1, return run_grouped_conv_fwd_bias_relu_add<1,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<1, DeviceGroupedConvNDFwdInstance<1,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -290,13 +299,16 @@ int main(int argc, char* argv[]) ...@@ -290,13 +299,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<2, return run_grouped_conv_fwd_bias_relu_add<2,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<2, DeviceGroupedConvNDFwdInstance<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -413,13 +425,16 @@ int main(int argc, char* argv[]) ...@@ -413,13 +425,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<3, return run_grouped_conv_fwd_bias_relu_add<3,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<3, DeviceGroupedConvNDFwdInstance<3,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -7,13 +7,19 @@ ...@@ -7,13 +7,19 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::half_t; // kernel data types
using WeiDataType = ck::half_t; using InKernelDataType = ck::half_t;
using WeiKernelDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = ck::half_t; using CShuffleDataType = ck::half_t;
using BiasDataType = ck::half_t; using BiasKernelDataType = ck::half_t;
using ResidualDataType = ck::half_t; using ResidualKernelDataType = ck::half_t;
using OutDataType = ck::half_t; using OutKernelDataType = ck::half_t;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance =
WeiLayout, WeiLayout,
ck::Tuple<BiasLayout, ResidualLayout>, ck::Tuple<BiasLayout, ResidualLayout>,
OutLayout, OutLayout,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasDataType, ResidualDataType>, ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -181,13 +187,16 @@ int main(int argc, char* argv[]) ...@@ -181,13 +187,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<1, return run_grouped_conv_fwd_bias_relu_add<1,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<1, DeviceGroupedConvNDFwdInstance<1,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -290,13 +299,16 @@ int main(int argc, char* argv[]) ...@@ -290,13 +299,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<2, return run_grouped_conv_fwd_bias_relu_add<2,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<2, DeviceGroupedConvNDFwdInstance<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -413,13 +425,16 @@ int main(int argc, char* argv[]) ...@@ -413,13 +425,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<3, return run_grouped_conv_fwd_bias_relu_add<3,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<3, DeviceGroupedConvNDFwdInstance<3,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -7,13 +7,19 @@ ...@@ -7,13 +7,19 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = float; // kernel data types
using WeiDataType = float; using InKernelDataType = float;
using WeiKernelDataType = float;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = float;
using BiasDataType = float; using BiasKernelDataType = float;
using ResidualDataType = float; using ResidualKernelDataType = float;
using OutDataType = float; using OutKernelDataType = float;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance =
WeiLayout, WeiLayout,
ck::Tuple<BiasLayout, ResidualLayout>, ck::Tuple<BiasLayout, ResidualLayout>,
OutLayout, OutLayout,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasDataType, ResidualDataType>, ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -181,13 +187,16 @@ int main(int argc, char* argv[]) ...@@ -181,13 +187,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<1, return run_grouped_conv_fwd_bias_relu_add<1,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<1, DeviceGroupedConvNDFwdInstance<1,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -290,13 +299,16 @@ int main(int argc, char* argv[]) ...@@ -290,13 +299,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<2, return run_grouped_conv_fwd_bias_relu_add<2,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<2, DeviceGroupedConvNDFwdInstance<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -413,13 +425,16 @@ int main(int argc, char* argv[]) ...@@ -413,13 +425,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<3, return run_grouped_conv_fwd_bias_relu_add<3,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<3, DeviceGroupedConvNDFwdInstance<3,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -7,13 +7,19 @@ ...@@ -7,13 +7,19 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = int8_t; // kernel data types
using WeiDataType = int8_t; using InKernelDataType = int8_t;
using WeiKernelDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CShuffleDataType = int8_t; using CShuffleDataType = int8_t;
using BiasDataType = int8_t; using BiasKernelDataType = int8_t;
using ResidualDataType = int8_t; using ResidualKernelDataType = int8_t;
using OutDataType = int8_t; using OutKernelDataType = int8_t;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -40,12 +46,12 @@ using DeviceGroupedConvNDFwdInstance =
WeiLayout, WeiLayout,
ck::Tuple<BiasLayout, ResidualLayout>, ck::Tuple<BiasLayout, ResidualLayout>,
OutLayout, OutLayout,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasDataType, ResidualDataType>, ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -181,13 +187,16 @@ int main(int argc, char* argv[]) ...@@ -181,13 +187,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<1, return run_grouped_conv_fwd_bias_relu_add<1,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<1, DeviceGroupedConvNDFwdInstance<1,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -290,13 +299,16 @@ int main(int argc, char* argv[]) ...@@ -290,13 +299,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<2, return run_grouped_conv_fwd_bias_relu_add<2,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<2, DeviceGroupedConvNDFwdInstance<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -413,13 +425,16 @@ int main(int argc, char* argv[]) ...@@ -413,13 +425,16 @@ int main(int argc, char* argv[])
}); });
return run_grouped_conv_fwd_bias_relu_add<3, return run_grouped_conv_fwd_bias_relu_add<3,
InDataType, InKernelDataType,
WeiDataType, WeiKernelDataType,
CShuffleDataType, CShuffleDataType,
OutDataType, OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
InUserDataType,
WeiUserDataType,
OutUserDataType,
DeviceGroupedConvNDFwdInstance<3, DeviceGroupedConvNDFwdInstance<3,
InLayout, InLayout,
WeiLayout, WeiLayout,
......
...@@ -98,6 +98,16 @@ struct AddReluAdd ...@@ -98,6 +98,16 @@ struct AddReluAdd
int32_t c = b + x2; int32_t c = b + x2;
y = c; y = c;
} }
template <>
__host__ __device__ constexpr void operator()<int4_t, int8_t, int4_t, int4_t>(
int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const
{
int32_t a = x0 + x1;
int32_t b = a > 0 ? a : 0;
int32_t c = b + x2;
y = c;
}
}; };
struct AddHardswishAdd struct AddHardswishAdd
......
...@@ -150,7 +150,12 @@ check_err(const std::vector<T>& out, ...@@ -150,7 +150,12 @@ check_err(const std::vector<T>& out,
} }
template <typename T> template <typename T>
typename std::enable_if<std::is_integral<T>::value && !std::is_same<T, bhalf_t>::value, bool>::type std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, int4_t>
#endif
,
bool>
check_err(const std::vector<T>& out, check_err(const std::vector<T>& out,
const std::vector<T>& ref, const std::vector<T>& ref,
const std::string& msg = "Error: Incorrect results!", const std::string& msg = "Error: Incorrect results!",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment