"vscode:/vscode.git/clone" did not exist on "2bb9444f4681a2d0431a2282c8198c470d6fa36c"
Commit 5f1d777b authored by ltqin's avatar ltqin
Browse files

fix save bfloat16x4_t

parent 27d764eb
......@@ -41,7 +41,8 @@ struct PassThrough
}
template <>
__host__ __device__ void operator()<bfloat16_t, bfloat16_t>(bfloat16_t& y, const bfloat16_t& x) const
__host__ __device__ void operator()<bfloat16_t, bfloat16_t>(bfloat16_t& y,
const bfloat16_t& x) const
{
y = x;
}
......
......@@ -544,7 +544,6 @@ struct MfmaSelector
#endif
}
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
{
......@@ -756,7 +755,8 @@ struct XdlopsGemm
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value|| is_same<base_type, bfloat16_t>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, bfloat16_t>::value ||
is_same<base_type, int8_t>::value,
"base base_type must be double, float, half, bfloat16, and int8_t!");
......
......@@ -424,18 +424,21 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i16(
auto tmp = llvm_amdgcn_raw_buffer_load_i16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<bfloat16_t>(tmp);
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_i16x2(
auto tmp = llvm_amdgcn_raw_buffer_load_i16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<bfloat16x2_t>(tmp);
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_i16x4(
auto tmp = llvm_amdgcn_raw_buffer_load_i16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return bit_cast<bfloat16x4_t>(tmp);
}
else if constexpr(N == 8)
{
......
......@@ -179,7 +179,7 @@ union U32BF162
{
uint32_t u32;
bhalf2_t bf162;
bfloat16x2_t bfloat16x2;
bfloat16x2_t bfloat16x2;
};
template <>
......@@ -211,14 +211,14 @@ __device__ bfloat16x2_t atomic_add<bfloat16x2_t>(bfloat16x2_t* p_dst, const bflo
U32BF162 new_;
uint32_t old_v, new_v;
dword_addr.bfloat16x2_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
old_v = cur_v.u32;
new_.bfloat16x2 = add_bf16x2_t(cur_v.bfloat16x2, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment