Commit 010ef9dc authored by rocking's avatar rocking
Browse files

replace ushortXXX_t to bhalfXXX_t

parent 63e10e34
...@@ -57,13 +57,13 @@ llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, ...@@ -57,13 +57,13 @@ llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
__device__ ushort2_t __device__ bhalf2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
__device__ ushort4_t __device__ bhalf4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
...@@ -156,14 +156,14 @@ llvm_amdgcn_raw_buffer_store_i16(ushort vdata, ...@@ -156,14 +156,14 @@ llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
__device__ void __device__ void
llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata, llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
__device__ void __device__ void
llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata, llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
...@@ -387,7 +387,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -387,7 +387,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<ushort8_t>(tmp); return bit_cast<bhalf8_t>(tmp);
} }
} }
else if constexpr(is_same<T, int32_t>::value) else if constexpr(is_same<T, int32_t>::value)
...@@ -655,13 +655,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -655,13 +655,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
{ {
3 vector_type<ushort, 8> tmp{src_thread_data}; 3 vector_type<ushort, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<ushort4_t>()[Number<0>{}], llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<ushort4_t>()[Number<1>{}], llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(ushort), dst_wave_addr_offset + 4 * sizeof(ushort),
......
...@@ -207,7 +207,7 @@ template <> ...@@ -207,7 +207,7 @@ template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
...@@ -221,7 +221,7 @@ template <> ...@@ -221,7 +221,7 @@ template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
...@@ -235,7 +235,7 @@ template <> ...@@ -235,7 +235,7 @@ template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32> struct intrin_mfma_f32_32x32x4bf16<32, 32>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
...@@ -249,7 +249,7 @@ template <> ...@@ -249,7 +249,7 @@ template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16> struct intrin_mfma_f32_16x16x8bf16<16, 16>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
namespace ck { namespace ck {
using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
// vector_type // vector_type
...@@ -904,12 +905,12 @@ using half32_t = typename vector_type<half_t, 32>::type; ...@@ -904,12 +905,12 @@ using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type; using half64_t = typename vector_type<half_t, 64>::type;
// bfp16 // bfp16
using ushort2_t = typename vector_type<ushort, 2>::type; using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type; using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type; using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type; using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type; using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type; using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
// i32 // i32
using int32x2_t = typename vector_type<int32_t, 2>::type; using int32x2_t = typename vector_type<int32_t, 2>::type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment