Commit 5f1d777b authored by ltqin's avatar ltqin
Browse files

fix save bfloat16x4_t

parent 27d764eb
...@@ -41,7 +41,8 @@ struct PassThrough ...@@ -41,7 +41,8 @@ struct PassThrough
} }
template <> template <>
__host__ __device__ void operator()<bfloat16_t, bfloat16_t>(bfloat16_t& y, const bfloat16_t& x) const __host__ __device__ void operator()<bfloat16_t, bfloat16_t>(bfloat16_t& y,
const bfloat16_t& x) const
{ {
y = x; y = x;
} }
......
...@@ -544,7 +544,6 @@ struct MfmaSelector ...@@ -544,7 +544,6 @@ struct MfmaSelector
#endif #endif
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() static constexpr auto GetMfma<int8_t, 32, 32>()
{ {
...@@ -756,7 +755,8 @@ struct XdlopsGemm ...@@ -756,7 +755,8 @@ struct XdlopsGemm
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value|| is_same<base_type, bfloat16_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, bfloat16_t>::value ||
is_same<base_type, int8_t>::value, is_same<base_type, int8_t>::value,
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, and int8_t!");
......
...@@ -424,18 +424,21 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -424,18 +424,21 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
return llvm_amdgcn_raw_buffer_load_i16( auto tmp = llvm_amdgcn_raw_buffer_load_i16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<bfloat16_t>(tmp);
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
return llvm_amdgcn_raw_buffer_load_i16x2( auto tmp = llvm_amdgcn_raw_buffer_load_i16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<bfloat16x2_t>(tmp);
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
return llvm_amdgcn_raw_buffer_load_i16x4( auto tmp = llvm_amdgcn_raw_buffer_load_i16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<bfloat16x4_t>(tmp);
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
......
...@@ -179,7 +179,7 @@ union U32BF162 ...@@ -179,7 +179,7 @@ union U32BF162
{ {
uint32_t u32; uint32_t u32;
bhalf2_t bf162; bhalf2_t bf162;
bfloat16x2_t bfloat16x2; bfloat16x2_t bfloat16x2;
}; };
template <> template <>
...@@ -211,14 +211,14 @@ __device__ bfloat16x2_t atomic_add<bfloat16x2_t>(bfloat16x2_t* p_dst, const bflo ...@@ -211,14 +211,14 @@ __device__ bfloat16x2_t atomic_add<bfloat16x2_t>(bfloat16x2_t* p_dst, const bflo
U32BF162 new_; U32BF162 new_;
uint32_t old_v, new_v; uint32_t old_v, new_v;
dword_addr.bfloat16x2_a = p_dst; dword_addr.bfloat16x2_a = p_dst;
cur_v.u32 = *dword_addr.u32_a; cur_v.u32 = *dword_addr.u32_a;
do do
{ {
old_v = cur_v.u32; old_v = cur_v.u32;
new_.bfloat16x2 = add_bf16x2_t(cur_v.bfloat16x2, x); new_.bfloat16x2 = add_bf16x2_t(cur_v.bfloat16x2, x);
new_v = new_.u32; new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v); } while(cur_v.u32 != old_v);
return x; return x;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment