Commit a13bf453 authored by rocking

rename ushort to bhalf_t

parent 010ef9dc
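This commit assumes a `bhalf_t` alias already defined in `data_type.hpp`; a minimal sketch of what that alias presumably looks like (CK stores bfloat16 values in a 16-bit unsigned integer, so the rename changes readability, not the ABI):

```cpp
// Hypothetical sketch -- the real definition lives in CK's data_type.hpp.
// bfloat16 has no built-in C++ type, so CK keeps the raw 16 bits in an
// unsigned 16-bit integer; the bhalf_t name makes that intent explicit.
namespace ck {
using bhalf_t = unsigned short; // 16-bit storage for bfloat16
} // namespace ck
```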
@@ -13,7 +13,7 @@ struct PassThrough
     __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }

-    __host__ __device__ void operator()(ushort& y, const ushort& x) const { y = x; }
+    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x; }

     __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }
...
@@ -474,7 +474,7 @@ struct MfmaSelector
     }

     template <>
-    static constexpr auto GetMfma<ushort, 32, 32>()
+    static constexpr auto GetMfma<bhalf_t, 32, 32>()
     {
 #if defined(CK_AMD_GPU_GFX90A)
         return MfmaInstr::mfma_f32_32x32x8bf16_1k;
@@ -484,7 +484,7 @@ struct MfmaSelector
     }

     template <>
-    static constexpr auto GetMfma<ushort, 16, 16>()
+    static constexpr auto GetMfma<bhalf_t, 16, 16>()
     {
 #if defined(CK_AMD_GPU_GFX90A)
         return MfmaInstr::mfma_f32_16x16x16bf16_1k;
@@ -662,8 +662,8 @@ struct XdlopsGemm
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
         static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
-                          is_same<base_type, ushort>::value || is_same<base_type, int8_t>::value,
-                      "base base_type must be float, half, ushort, and int8_t!");
+                          is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value,
+                      "base base_type must be float, half, bfloat16, and int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
             mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
...
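For context on the instruction names above: `mfma_f32_32x32x8bf16_1k` computes a 32x32 (MxN) by K=8 matrix multiply-accumulate on bf16 inputs with f32 accumulation; the `_1k` variants are the gfx90a encodings with doubled bf16 K-throughput over the gfx908 originals, which is why the selector branches on `CK_AMD_GPU_GFX90A`. A scalar sketch of the arithmetic only (the `bhalf_t = uint16_t` storage type is an assumption carried over from the sketch above; the real instruction spreads the A/B/C fragments across a 64-lane wavefront):

```cpp
// Scalar reference for one M x N x K MFMA step with bf16 inputs and f32
// accumulation (M = N = 32, K = 8 for mfma_f32_32x32x8bf16_1k). This loop
// only pins down the arithmetic, not the hardware fragment layout.
#include <cstdint>
#include <cstring>

using bhalf_t = std::uint16_t; // assumed bf16 storage type, as sketched above

static float bf16_as_f32(bhalf_t x) // bf16 is the high half of an f32
{
    std::uint32_t u = std::uint32_t(x) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

template <int M, int N, int K>
void mfma_reference(const bhalf_t* a, const bhalf_t* b, float* c)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += bf16_as_f32(a[m * K + k]) * bf16_as_f32(b[k * N + n]);
}
```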
@@ -51,7 +51,7 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");

 // buffer load i16
-__device__ ushort
+__device__ bhalf_t
 llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
                                 index_t voffset,
                                 index_t soffset,
@@ -149,7 +149,7 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
 // buffer store i16
 __device__ void
-llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
+llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata,
                                  int32x4_t rsrc,
                                  index_t voffset,
                                  index_t soffset,
@@ -266,7 +266,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         (is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
             (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-            (is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
@@ -365,7 +365,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             return bit_cast<half8_t>(tmp);
         }
     }
-    else if constexpr(is_same<T, ushort>::value)
+    else if constexpr(is_same<T, bhalf_t>::value)
     {
         if constexpr(N == 1)
         {
@@ -522,7 +522,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
         (is_same<T, double>::value && (N == 1 || N == 2)) ||
             (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
             (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-            (is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
             (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
@@ -625,7 +625,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
 #endif
         }
     }
-    else if constexpr(is_same<T, ushort>::value)
+    else if constexpr(is_same<T, bhalf_t>::value)
     {
         if constexpr(N == 1)
         {
@@ -653,7 +653,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
         }
         else if constexpr(N == 8)
         {
-            vector_type<ushort, 8> tmp{src_thread_data};
+            vector_type<bhalf_t, 8> tmp{src_thread_data};

             llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
                                                dst_wave_buffer_resource,
@@ -664,7 +664,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
             llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
-                                               dst_wave_addr_offset + 4 * sizeof(ushort),
+                                               dst_wave_addr_offset + 4 * sizeof(bhalf_t),
                                                0);
         }
     }
...
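One detail worth calling out in the store path above: the widest raw-buffer i16 store wrapper used here handles four lanes per call (`llvm_amdgcn_raw_buffer_store_i16x4`), so the `N == 8` branch splits an 8-element bhalf vector into two 4-wide stores, offsetting the second by `4 * sizeof(bhalf_t)` = 8 bytes. A host-side sketch of the same split, with a plain `memcpy` and pointer bump standing in for the intrinsic and its voffset/soffset arithmetic:

```cpp
// Host-side sketch of the N == 8 store split; store4 stands in for the
// llvm_amdgcn_raw_buffer_store_i16x4 wrapper used on the GPU side.
#include <cstdint>
#include <cstring>

using bhalf_t = std::uint16_t; // assumed bf16 storage type

static void store4(const bhalf_t* src, unsigned char* dst)
{
    std::memcpy(dst, src, 4 * sizeof(bhalf_t)); // one 4 x i16 store
}

static void store8(const bhalf_t* src, unsigned char* dst)
{
    store4(src, dst);                           // lanes 0..3
    store4(src + 4, dst + 4 * sizeof(bhalf_t)); // lanes 4..7, +8 bytes
}
```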
@@ -108,9 +108,9 @@ struct scalar_type<half_t>
 };

 template <>
-struct scalar_type<ushort>
+struct scalar_type<bhalf_t>
 {
-    using type                           = ushort;
+    using type                           = bhalf_t;
     static constexpr index_t vector_size = 1;
 };
@@ -937,7 +937,7 @@ __host__ __device__ Y type_convert(X x)
 // convert bfp16 to fp32
 template <>
-inline __host__ __device__ float type_convert(ushort x)
+inline __host__ __device__ float type_convert(bhalf_t x)
 {
     union
     {
@@ -950,7 +950,7 @@ inline __host__ __device__ float type_convert(ushort x)
 // convert fp32 to bfp16
 template <>
-inline __host__ __device__ ushort type_convert(float x)
+inline __host__ __device__ bhalf_t type_convert(float x)
 {
     union
     {
...
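The union bodies of both `type_convert` specializations are collapsed in the hunk above. bf16 is by construction the top half of an f32, so the bf16→f32 direction is exact (a 16-bit left shift of the raw bits); the f32→bf16 direction must drop 16 mantissa bits, either by truncating (round toward zero) or by rounding to nearest even. A sketch of both narrowing variants, hedged because the diff does not show which one CK uses here:

```cpp
// Sketch of the collapsed conversion bodies. Widening is exact; narrowing
// either truncates or rounds to nearest even -- which one CK uses is not
// visible in this diff. NaN handling is omitted for brevity.
#include <cstdint>
#include <cstring>

using bhalf_t = std::uint16_t; // assumed bf16 storage type

float bf16_to_f32(bhalf_t x)
{
    std::uint32_t u = std::uint32_t(x) << 16; // bf16 = top half of f32
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

bhalf_t f32_to_bf16_truncate(float x) // round toward zero
{
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    return bhalf_t(u >> 16); // keep sign, exponent, top 7 mantissa bits
}

bhalf_t f32_to_bf16_rne(float x) // round to nearest, ties to even
{
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u += 0x7FFFu + ((u >> 16) & 1u); // bias by half a ULP, break ties to even
    return bhalf_t(u >> 16);
}
```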
@@ -77,7 +77,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
                 if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                    wi < in.mDesc.GetLengths()[3])
                 {
-                    if constexpr(is_same<TIn, ushort>::value)
+                    if constexpr(is_same<TIn, bhalf_t>::value)
                     {
                         v += ck::type_convert<float>(in(n, c, hi, wi)) *
                              ck::type_convert<float>(wei(k, c, y, x));
@@ -92,9 +92,9 @@ void host_convolution_forward(const Tensor<TIn>& in,
                 }
             }
-        if constexpr(is_same<TOut, ushort>::value)
+        if constexpr(is_same<TOut, bhalf_t>::value)
         {
-            out(n, k, ho, wo) = ck::type_convert<ushort>(static_cast<float>(v));
+            out(n, k, ho, wo) = ck::type_convert<bhalf_t>(static_cast<float>(v));
         }
         else
         {
@@ -115,7 +115,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
                 if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
                    wi < in.mDesc.GetLengths()[2])
                 {
-                    if constexpr(is_same<TIn, ushort>::value)
+                    if constexpr(is_same<TIn, bhalf_t>::value)
                     {
                         v += ck::type_convert<float>(in(n, hi, wi, c)) *
                              ck::type_convert<float>(wei(k, y, x, c));
@@ -129,9 +129,9 @@ void host_convolution_forward(const Tensor<TIn>& in,
                 }
             }
         }
-        if constexpr(is_same<TOut, ushort>::value)
+        if constexpr(is_same<TOut, bhalf_t>::value)
         {
-            out(n, ho, wo, k) = ck::type_convert<ushort>(static_cast<float>(v));
+            out(n, ho, wo, k) = ck::type_convert<bhalf_t>(static_cast<float>(v));
         }
         else
         {
@@ -259,9 +259,9 @@ int main(int argc, char* argv[])
     using acc_data_t = float;
     using out_data_t = half_t;
 #elif 0
-    using in_data_t  = ushort;
+    using in_data_t  = bhalf_t;
     using acc_data_t = float;
-    using out_data_t = ushort;
+    using out_data_t = bhalf_t;
 #elif 1
     using in_data_t  = int8_t;
     using acc_data_t = int32_t;
...
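The host reference above follows a pattern worth naming: bf16 operands are widened with `ck::type_convert<float>` before every multiply-accumulate, and only the finished accumulator is narrowed back to bf16. With just 8 mantissa bits, accumulating directly in bf16 would shed low-order bits on every addition. A condensed sketch of the same discipline, reusing the conversion helpers sketched in the `type_convert` hunk:

```cpp
// Condensed form of the bf16 path above: widen each operand to f32,
// accumulate in f32, and narrow exactly once at the end.
#include <cstddef>
#include <cstdint>

using bhalf_t = std::uint16_t;

float bf16_to_f32(bhalf_t x);          // as sketched earlier
bhalf_t f32_to_bf16_truncate(float x); // as sketched earlier

float dot_bf16(const bhalf_t* in, const bhalf_t* wei, std::size_t n)
{
    float acc = 0.f;
    for(std::size_t i = 0; i < n; ++i)
        acc += bf16_to_f32(in[i]) * bf16_to_f32(wei[i]);
    return acc; // narrow with f32_to_bf16_truncate only if TOut is bf16
}
```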
 include_directories(BEFORE
+    ${PROJECT_SOURCE_DIR}/composable_kernel/include
+    ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
     include
 )
...
@@ -8,6 +8,8 @@
 #include <utility>
 #include <cassert>
 #include <iostream>
+
+#include "data_type.hpp"

 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
@@ -311,7 +313,7 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> s
 void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);

-float bf16_to_f32_(ushort src_val);
+float bf16_to_f32_(ck::bhalf_t src_val);

 template <typename T>
 void check_error(const Tensor<T>& ref, const Tensor<T>& result)
@@ -320,7 +322,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
     float max_diff = -1;
     float ref_value = 0, result_value = 0;
-    if constexpr(std::is_same<ushort, T>::value)
+    if constexpr(std::is_same<ck::bhalf_t, T>::value)
     {
         for(int i = 0; i < ref.mData.size(); ++i)
         {
...
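Similarly, `check_error` compares bf16 tensors only after converting both sides with `bf16_to_f32_`: comparing the raw 16-bit patterns would mis-order negative values and make difference magnitudes meaningless. The loop body is collapsed above; a hypothetical standalone version of what it presumably does:

```cpp
// Hypothetical standalone version of the collapsed bf16 branch of
// check_error: convert both sides to f32, then track the maximum
// absolute difference.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

using bhalf_t = std::uint16_t;

float bf16_to_f32(bhalf_t x); // as sketched earlier

float max_abs_diff_bf16(const std::vector<bhalf_t>& ref,
                        const std::vector<bhalf_t>& result)
{
    float max_diff = -1.f; // matches the sentinel used in check_error
    for(std::size_t i = 0; i < ref.size(); ++i)
        max_diff = std::max(
            max_diff, std::fabs(bf16_to_f32(ref[i]) - bf16_to_f32(result[i])));
    return max_diff;
}
```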
@@ -3,7 +3,6 @@
 #include <cmath>

 #include "config.hpp"
-#include "data_type.hpp"

 template <typename T>
 struct GeneratorTensor_0
@@ -28,14 +27,14 @@ struct GeneratorTensor_1
 };

 template <>
-struct GeneratorTensor_1<ushort>
+struct GeneratorTensor_1<ck::bhalf_t>
 {
     float value = 1.0;

     template <typename... Is>
-    ushort operator()(Is...)
+    ck::bhalf_t operator()(Is...)
     {
-        return ck::type_convert<ushort>(value);
+        return ck::type_convert<ck::bhalf_t>(value);
     }
 };
@@ -65,16 +64,16 @@ struct GeneratorTensor_2
 };

 template <>
-struct GeneratorTensor_2<ushort>
+struct GeneratorTensor_2<ck::bhalf_t>
 {
     int min_value = 0;
     int max_value = 1;

     template <typename... Is>
-    ushort operator()(Is...)
+    ck::bhalf_t operator()(Is...)
     {
         float tmp = (std::rand() % (max_value - min_value)) + min_value;
-        return ck::type_convert<ushort>(tmp);
+        return ck::type_convert<ck::bhalf_t>(tmp);
     }
 };
@@ -107,19 +106,19 @@ struct GeneratorTensor_3
 };

 template <>
-struct GeneratorTensor_3<ushort>
+struct GeneratorTensor_3<ck::bhalf_t>
 {
     float min_value = 0;
     float max_value = 1;

     template <typename... Is>
-    ushort operator()(Is...)
+    ck::bhalf_t operator()(Is...)
     {
         float tmp = float(std::rand()) / float(RAND_MAX);

         float fp32_tmp = min_value + tmp * (max_value - min_value);

-        return ck::type_convert<ushort>(fp32_tmp);
+        return ck::type_convert<ck::bhalf_t>(fp32_tmp);
     }
 };
...
 #include <cassert>

 #include "host_tensor.hpp"

 void HostTensorDescriptor::CalculateStrides()
@@ -65,7 +64,7 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream
     os << "}" << std::endl;
 }

-float bf16_to_f32_(ushort src_val)
+float bf16_to_f32_(ck::bhalf_t src_val)
 {
     union
     {
...
@@ -174,9 +174,9 @@ void profile_conv_fwd_impl(int do_verification,
         ck::tensor_operation::device::device_conv2d_fwd_instance::
             add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
     }
-    else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
-                      ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
-                      ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
+    else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, bhalf_t> &&
+                      ck::is_same_v<ck::remove_cv_t<WeiDataType>, bhalf_t> &&
+                      ck::is_same_v<ck::remove_cv_t<OutDataType>, bhalf_t>)
     {
         ck::tensor_operation::device::device_conv2d_fwd_instance::
             add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
...
@@ -198,9 +198,9 @@ int main(int argc, char* argv[])
         ck::tensor_operation::device::device_conv2d_fwd_instance::
             add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
     }
-    else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
-                      ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
-                      ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
+    else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, bhalf_t> &&
+                      ck::is_same_v<ck::remove_cv_t<WeiDataType>, bhalf_t> &&
+                      ck::is_same_v<ck::remove_cv_t<OutDataType>, bhalf_t>)
     {
         ck::tensor_operation::device::device_conv2d_fwd_instance::
             add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
@@ -292,7 +292,7 @@ int main(int argc, char* argv[])
     }
     else if(data_type == 2)
     {
-        Run(ushort(), ushort(), ushort());
+        Run(bhalf_t(), bhalf_t(), bhalf_t());
     }
     else if(data_type == 3)
     {
...