Commit 4d140b5d authored by danyao12's avatar danyao12
Browse files

fp16_2&bf16_2 atomicCAS

parent 9cf17a90
......@@ -79,23 +79,56 @@ inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t&
return rtn;
}
union U32FP162_ADDR
{
uint32_t* u32_a;
half2_t* fp162_a;
};
union U32FP162
{
uint32_t u32;
half2_t fp162;
};
template <>
__device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
{
uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
uint32_t cur_v = *dword_addr;
U32FP162_ADDR dword_addr;
U32FP162 cur_v;
U32FP162 new_;
uint32_t old_v, new_v;
dword_addr.fp162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do {
old_v = cur_v;
half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
new_v = *reinterpret_cast<uint32_t*>(&new_);
cur_v = atomicCAS(dword_addr, old_v, new_v);
}while(cur_v != old_v);
do
{
old_v = cur_v.u32;
new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x;
}
// template <>
// __device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
// union U16BF16 {
// uint16_t u16;
// bhalf_t bf16;
......@@ -123,23 +156,56 @@ inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2
return rtn;
}
union U32BF162_ADDR
{
uint32_t* u32_a;
bhalf2_t* bf162_a;
};
union U32BF162
{
uint32_t u32;
bhalf2_t bf162;
};
template <>
__device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
{
uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
uint32_t cur_v = *dword_addr;
U32BF162_ADDR dword_addr;
U32BF162 cur_v;
U32BF162 new_;
uint32_t old_v, new_v;
dword_addr.bf162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do {
old_v = cur_v;
bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
new_v = *reinterpret_cast<uint32_t*>(&new_);
cur_v = atomicCAS(dword_addr, old_v, new_v);
}while(cur_v != old_v);
do
{
old_v = cur_v.u32;
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x;
}
// template <>
// __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment