Commit f3edca63 authored by ltqin's avatar ltqin
Browse files

bfloat16_t save and atomic

parent da0bb989
......@@ -50,7 +50,7 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using BF16 = ck::bfloat16_t;
using F32 = float;
using U16 = unsigned short;
......
......@@ -296,6 +296,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bfloat16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
......@@ -419,6 +420,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
return bit_cast<bhalf8_t>(tmp);
}
}
else if constexpr(is_same<T, bfloat16_t>::value)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_i16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_i16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<bfloat16x8_t>(tmp);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
......@@ -552,6 +578,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bfloat16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
......@@ -697,6 +724,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0);
}
}
else if constexpr(is_same<T, bfloat16_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<bhalf_t, 8> tmp{bit_cast<vector_type<bhalf_t, 8>>(src_thread_data)};
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bhalf_t),
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
......
......@@ -1052,7 +1052,7 @@ inline __host__ __device__ constexpr float type_convert<float, bfloat16_t>(bfloa
return u.fp32;
}
// convert fp32 to bfp16
// convert fp32 to bfp16, rtz
template <>
inline __host__ __device__ constexpr bfloat16_t type_convert<bfloat16_t, float>(float x)
{
......
......@@ -189,6 +189,27 @@ __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
return x;
}
template <>
__device__ bfloat16x2_t atomic_add<bfloat16x2_t>(bfloat16x2_t* p_dst, const bfloat16x2_t& x)
{
U32BF162_ADDR dword_addr;
U32BF162 cur_v;
U32BF162 new_;
uint32_t old_v, new_v;
dword_addr.bf162_a = reinterpret_cast<bhalf2_t*>(p_dst);
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.bf162 = add_bf16x2_t(cur_v.bf162, reinterpret_cast<bhalf2_t>(x));
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x;
}
// template <>
// __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
// {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment