Commit 4d140b5d authored by danyao12's avatar danyao12
Browse files

fp16_2&bf16_2 atomicCAS

parent 9cf17a90
...@@ -79,32 +79,65 @@ inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t& ...@@ -79,32 +79,65 @@ inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t&
return rtn; return rtn;
} }
union U32FP162_ADDR
{
uint32_t* u32_a;
half2_t* fp162_a;
};
union U32FP162
{
uint32_t u32;
half2_t fp162;
};
template <> template <>
__device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x) __device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
{ {
uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst); U32FP162_ADDR dword_addr;
uint32_t cur_v = *dword_addr; U32FP162 cur_v;
U32FP162 new_;
uint32_t old_v, new_v; uint32_t old_v, new_v;
dword_addr.fp162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do { do
old_v = cur_v; {
half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x); old_v = cur_v.u32;
new_v = *reinterpret_cast<uint32_t*>(&new_); new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
cur_v = atomicCAS(dword_addr, old_v, new_v); new_v = new_.u32;
}while(cur_v != old_v); cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x; return x;
} }
// template <>
// __device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
// union U16BF16 { // union U16BF16 {
// uint16_t u16; // uint16_t u16;
// bhalf_t bf16; // bhalf_t bf16;
// }; // };
// inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b){ // inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b){
// U16BF16 xa {.bf16 = a}; // U16BF16 xa {.bf16 = a};
// U16BF16 xb {.bf16 = b}; // U16BF16 xb {.bf16 = b};
// U16BF16 xr; // U16BF16 xr;
// xr.u16 = xa.u16 + xb.u16; // xr.u16 = xa.u16 + xb.u16;
// return xr.bf16; // return xr.bf16;
...@@ -123,23 +156,56 @@ inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2 ...@@ -123,23 +156,56 @@ inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2
return rtn; return rtn;
} }
union U32BF162_ADDR
{
uint32_t* u32_a;
bhalf2_t* bf162_a;
};
union U32BF162
{
uint32_t u32;
bhalf2_t bf162;
};
template <> template <>
__device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x) __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
{ {
uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst); U32BF162_ADDR dword_addr;
uint32_t cur_v = *dword_addr; U32BF162 cur_v;
U32BF162 new_;
uint32_t old_v, new_v; uint32_t old_v, new_v;
dword_addr.bf162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do { do
old_v = cur_v; {
bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x); old_v = cur_v.u32;
new_v = *reinterpret_cast<uint32_t*>(&new_); new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
cur_v = atomicCAS(dword_addr, old_v, new_v); new_v = new_.u32;
}while(cur_v != old_v); cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x; return x;
} }
// template <>
// __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for // instantiate this template. The purpose is to make the implementation of atomic_max explicit for
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment