Commit a2232209 authored by Rosty Geyyer
Browse files

Convert int4 tensors for int8 kernel

parent bf40c70b
...@@ -55,7 +55,7 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -55,7 +55,7 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
Tensor<OutUserDataType> bias(bias_g_n_k_wos_desc); Tensor<OutUserDataType> bias(bias_g_n_k_wos_desc);
Tensor<OutUserDataType> residual(residual_g_n_k_wos_desc); Tensor<OutUserDataType> residual(residual_g_n_k_wos_desc);
Tensor<OutUserDataType> out_host(out_g_n_k_wos_desc); Tensor<OutUserDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutUserDataType> out_device(out_g_n_k_wos_desc); Tensor<OutKernelDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl;
...@@ -83,10 +83,22 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -83,10 +83,22 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
DeviceMem residual_device_buf(sizeof(OutUserDataType) * residual.mDesc.GetElementSpaceSize()); DeviceMem residual_device_buf(sizeof(OutUserDataType) * residual.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutUserDataType) * out_device.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutUserDataType) * out_device.mDesc.GetElementSpaceSize());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
const Tensor<InKernelDataType> in_converted(in);
const Tensor<WeiKernelDataType> wei_converted(wei);
const Tensor<OutKernelDataType> bias_converted(bias);
const Tensor<OutKernelDataType> residual_converted(residual);
in_device_buf.ToDevice(in_converted.mData.data());
wei_device_buf.ToDevice(wei_converted.mData.data());
bias_device_buf.ToDevice(bias_converted.mData.data());
residual_device_buf.ToDevice(residual_converted.mData.data());
#else // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
in_device_buf.ToDevice(in.mData.data()); in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data()); bias_device_buf.ToDevice(bias.mData.data());
residual_device_buf.ToDevice(residual.mData.data()); residual_device_buf.ToDevice(residual.mData.data());
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{}; std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
...@@ -199,10 +211,22 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification, ...@@ -199,10 +211,22 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
const Tensor<OutUserDataType> out_device_converted(out_device);
return ck::utils::check_err(out_device_converted.mData,
out_host.mData,
"Error: incorrect results!",
1e-5f,
1e-4f)
? 0
: 1;
#else // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
return ck::utils::check_err( return ck::utils::check_err(
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f) out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
? 0 ? 0
: 1; : 1;
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
} }
return 0; return 0;
......
...@@ -162,7 +162,7 @@ struct joinable_thread : std::thread ...@@ -162,7 +162,7 @@ struct joinable_thread : std::thread
{ {
} }
joinable_thread(joinable_thread&&) = default; joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default; joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread() ~joinable_thread()
...@@ -254,7 +254,7 @@ struct Tensor ...@@ -254,7 +254,7 @@ struct Tensor
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {} Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
template <typename OutT> template <typename OutT>
Tensor<OutT> CopyAsType() Tensor<OutT> CopyAsType() const
{ {
Tensor<OutT> ret(mDesc); Tensor<OutT> ret(mDesc);
for(size_t i = 0; i < mData.size(); i++) for(size_t i = 0; i < mData.size(); i++)
...@@ -264,13 +264,18 @@ struct Tensor ...@@ -264,13 +264,18 @@ struct Tensor
return ret; return ret;
} }
Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {} Tensor() = delete;
Tensor(const Tensor&) = default;
Tensor(Tensor&&) = default;
Tensor& operator=(const Tensor& other) ~Tensor() = default;
Tensor& operator=(const Tensor&) = default;
Tensor& operator=(Tensor&&) = default;
template <typename FromT>
explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
{ {
mDesc = other.mDesc;
mData = other.mData;
return *this;
} }
const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); } const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment