Commit 9dce6851 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents 3cc57101 5d37d7bf
......@@ -51,19 +51,19 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
// buffer load i16
__device__ ushort
__device__ bhalf_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
__device__ ushort2_t
__device__ bhalf2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
__device__ ushort4_t
__device__ bhalf4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
......@@ -149,21 +149,21 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
// buffer store i16
__device__ void
llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
__device__ void
llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata,
llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
__device__ void
llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata,
llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
......@@ -266,7 +266,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
......@@ -365,7 +365,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
return bit_cast<half8_t>(tmp);
}
}
else if constexpr(is_same<T, ushort>::value)
else if constexpr(is_same<T, bhalf_t>::value)
{
if constexpr(N == 1)
{
......@@ -387,7 +387,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<ushort8_t>(tmp);
return bit_cast<bhalf8_t>(tmp);
}
}
else if constexpr(is_same<T, int32_t>::value)
......@@ -522,7 +522,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
......@@ -625,7 +625,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
#endif
}
}
else if constexpr(is_same<T, ushort>::value)
else if constexpr(is_same<T, bhalf_t>::value)
{
if constexpr(N == 1)
{
......@@ -653,19 +653,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp{src_thread_data};
vector_type<bhalf_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bhalf_t),
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
......
......@@ -207,7 +207,7 @@ template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{
template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
__device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
......@@ -221,7 +221,7 @@ template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{
template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
__device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
......@@ -235,7 +235,7 @@ template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
{
template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
__device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
......@@ -249,7 +249,7 @@ template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16>
{
template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
__device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
......
......@@ -5,7 +5,8 @@
namespace ck {
using half_t = _Float16;
using bhalf_t = ushort;
using half_t = _Float16;
// vector_type
template <typename T, index_t N>
......@@ -107,9 +108,9 @@ struct scalar_type<half_t>
};
template <>
struct scalar_type<ushort>
struct scalar_type<bhalf_t>
{
using type = ushort;
using type = bhalf_t;
static constexpr index_t vector_size = 1;
};
......@@ -904,12 +905,12 @@ using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
......@@ -936,7 +937,7 @@ __host__ __device__ Y type_convert(X x)
// convert bfp16 to fp32
template <>
inline __host__ __device__ float type_convert(ushort x)
inline __host__ __device__ float type_convert(bhalf_t x)
{
union
{
......@@ -949,7 +950,7 @@ inline __host__ __device__ float type_convert(ushort x)
// convert fp32 to bfp16
template <>
inline __host__ __device__ ushort type_convert(float x)
inline __host__ __device__ bhalf_t type_convert(float x)
{
union
{
......
......@@ -3,6 +3,7 @@
#include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp"
#include "config.hpp"
#include "enable_if.hpp"
namespace ck {
......@@ -108,6 +109,30 @@ struct DynamicBuffer
}
}
template <InMemoryDataOperationEnum_t Op,
typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == InMemoryDataOperationEnum_t::Set)
{
this->template Set<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum_t::AtomicAdd)
{
this->template AtomicAdd<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum_t::Add)
{
auto tmp = this->template Get<X>(i, is_valid_element);
this->template Set<X>(i, is_valid_element, x + tmp);
// tmp += x;
// this->template Set<X>(i, is_valid_element, tmp);
}
}
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment