Commit 79ba5b99 authored by ltqin's avatar ltqin
Browse files

fix about atomic

parent f3edca63
......@@ -156,16 +156,30 @@ inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2
return rtn;
}
inline __host__ __device__ bfloat16_t add_bf16_t(const bfloat16_t& a, const bfloat16_t& b)
{
return type_convert<bfloat16_t>(type_convert<float>(a) + type_convert<float>(b));
}
inline __host__ __device__ bfloat16x2_t add_bf16x2_t(const bfloat16x2_t& a, const bfloat16x2_t& b)
{
bfloat16x2_t rtn;
rtn[0] = add_bf16_t(a[0], b[0]);
rtn[1] = add_bf16_t(a[1], b[1]);
return rtn;
}
union U32BF162_ADDR
{
uint32_t* u32_a;
bhalf2_t* bf162_a;
bfloat16x2_t* bfloat16x2_a;
};
union U32BF162
{
uint32_t u32;
bhalf2_t bf162;
bfloat16x2_t bfloat16x2;
};
template <>
......@@ -196,13 +210,13 @@ __device__ bfloat16x2_t atomic_add<bfloat16x2_t>(bfloat16x2_t* p_dst, const bflo
U32BF162 cur_v;
U32BF162 new_;
uint32_t old_v, new_v;
dword_addr.bf162_a = reinterpret_cast<bhalf2_t*>(p_dst);
dword_addr.bfloat16x2_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.bf162 = add_bf16x2_t(cur_v.bf162, reinterpret_cast<bhalf2_t>(x));
new_.bfloat16x2 = add_bf16x2_t(cur_v.bfloat16x2, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment