Commit 9dce6851 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents 3cc57101 5d37d7bf
...@@ -51,19 +51,19 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, ...@@ -51,19 +51,19 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
// buffer load i16 // buffer load i16
__device__ ushort __device__ bhalf_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
__device__ ushort2_t __device__ bhalf2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
__device__ ushort4_t __device__ bhalf4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
...@@ -149,21 +149,21 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, ...@@ -149,21 +149,21 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
// buffer store i16 // buffer store i16
__device__ void __device__ void
llvm_amdgcn_raw_buffer_store_i16(ushort vdata, llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
__device__ void __device__ void
llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata, llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
__device__ void __device__ void
llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata, llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
...@@ -266,7 +266,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -266,7 +266,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
...@@ -365,7 +365,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -365,7 +365,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
return bit_cast<half8_t>(tmp); return bit_cast<half8_t>(tmp);
} }
} }
else if constexpr(is_same<T, ushort>::value) else if constexpr(is_same<T, bhalf_t>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
...@@ -387,7 +387,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -387,7 +387,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<ushort8_t>(tmp); return bit_cast<bhalf8_t>(tmp);
} }
} }
else if constexpr(is_same<T, int32_t>::value) else if constexpr(is_same<T, int32_t>::value)
...@@ -522,7 +522,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -522,7 +522,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, double>::value && (N == 1 || N == 2)) || (is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
...@@ -625,7 +625,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -625,7 +625,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
#endif #endif
} }
} }
else if constexpr(is_same<T, ushort>::value) else if constexpr(is_same<T, bhalf_t>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
...@@ -653,19 +653,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -653,19 +653,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
vector_type<half_t, 8> tmp{src_thread_data}; vector_type<bhalf_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}], llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}], llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t), dst_wave_addr_offset + 4 * sizeof(bhalf_t),
0); 0);
} }
} }
else if constexpr(is_same<T, int32_t>::value) else if constexpr(is_same<T, int32_t>::value)
......
...@@ -207,7 +207,7 @@ template <> ...@@ -207,7 +207,7 @@ template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
...@@ -221,7 +221,7 @@ template <> ...@@ -221,7 +221,7 @@ template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
...@@ -235,7 +235,7 @@ template <> ...@@ -235,7 +235,7 @@ template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32> struct intrin_mfma_f32_32x32x4bf16<32, 32>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
...@@ -249,7 +249,7 @@ template <> ...@@ -249,7 +249,7 @@ template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16> struct intrin_mfma_f32_16x16x8bf16<16, 16>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
namespace ck { namespace ck {
using half_t = _Float16; using bhalf_t = ushort;
using half_t = _Float16;
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -107,9 +108,9 @@ struct scalar_type<half_t> ...@@ -107,9 +108,9 @@ struct scalar_type<half_t>
}; };
template <> template <>
struct scalar_type<ushort> struct scalar_type<bhalf_t>
{ {
using type = ushort; using type = bhalf_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
...@@ -904,12 +905,12 @@ using half32_t = typename vector_type<half_t, 32>::type; ...@@ -904,12 +905,12 @@ using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type; using half64_t = typename vector_type<half_t, 64>::type;
// bfp16 // bfp16
using ushort2_t = typename vector_type<ushort, 2>::type; using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type; using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type; using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type; using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type; using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type; using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
// i32 // i32
using int32x2_t = typename vector_type<int32_t, 2>::type; using int32x2_t = typename vector_type<int32_t, 2>::type;
...@@ -936,7 +937,7 @@ __host__ __device__ Y type_convert(X x) ...@@ -936,7 +937,7 @@ __host__ __device__ Y type_convert(X x)
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ float type_convert(ushort x) inline __host__ __device__ float type_convert(bhalf_t x)
{ {
union union
{ {
...@@ -949,7 +950,7 @@ inline __host__ __device__ float type_convert(ushort x) ...@@ -949,7 +950,7 @@ inline __host__ __device__ float type_convert(ushort x)
// convert fp32 to bfp16 // convert fp32 to bfp16
template <> template <>
inline __host__ __device__ ushort type_convert(float x) inline __host__ __device__ bhalf_t type_convert(float x)
{ {
union union
{ {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "amd_buffer_addressing.hpp" #include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp" #include "c_style_pointer_cast.hpp"
#include "config.hpp"
#include "enable_if.hpp" #include "enable_if.hpp"
namespace ck { namespace ck {
...@@ -108,6 +109,30 @@ struct DynamicBuffer ...@@ -108,6 +109,30 @@ struct DynamicBuffer
} }
} }
template <InMemoryDataOperationEnum_t Op,
typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == InMemoryDataOperationEnum_t::Set)
{
this->template Set<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum_t::AtomicAdd)
{
this->template AtomicAdd<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum_t::Add)
{
auto tmp = this->template Get<X>(i, is_valid_element);
this->template Set<X>(i, is_valid_element, x + tmp);
// tmp += x;
// this->template Set<X>(i, is_valid_element, tmp);
}
}
template <typename X, template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type, typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value, typename scalar_type<remove_cvref_t<T>>::type>::value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment