Unverified Commit 54d4bd62 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Support 16-bit shfl_sync (#1169)

* Add type-safe warp shuffle helpers for 16-bit float types in common.h

- Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
- Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
- Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.

* lint fix
parent 7a80b6df
......@@ -379,3 +379,91 @@ namespace cutlass {
// Fast exponential for bfloat16 inputs; delegates to the device intrinsic
// ::hexp. NOTE(review): presumably resolves to the __nv_bfloat16 overload of
// hexp (half-precision exp) — confirm the target arch provides it.
TL_DEVICE
bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); }
} // namespace cutlass
//
// Type-safe warp shuffle helpers for 16-bit float types
// These wrappers avoid relying on implicit conversions that may be disallowed
// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to
// float for the shuffle and then down-converting.
//
namespace tl {

// ---------------------------------------------------------------------------
// Generic passthroughs: for types the CUDA builtins accept directly, forward
// straight to the corresponding __shfl_*_sync intrinsic.
// ---------------------------------------------------------------------------
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
  return __shfl_xor_sync(mask, val, laneMask);
}

template <typename T>
TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
  return __shfl_down_sync(mask, val, delta);
}

template <typename T>
TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
  return __shfl_up_sync(mask, val, delta);
}

template <typename T> TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
  return __shfl_sync(mask, val, srcLane);
}

// ---------------------------------------------------------------------------
// cutlass::half_t specializations: widen to float, shuffle, narrow back.
// Explicit widening avoids relying on implicit float -> half_t conversions,
// which may be disallowed.
// ---------------------------------------------------------------------------
template <>
TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
  return half_t(__shfl_xor_sync(mask, static_cast<float>(val), laneMask));
}

template <>
TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
  return half_t(__shfl_down_sync(mask, static_cast<float>(val), delta));
}

template <>
TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
  return half_t(__shfl_up_sync(mask, static_cast<float>(val), delta));
}

template <> TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
  return half_t(__shfl_sync(mask, static_cast<float>(val), srcLane));
}

// ---------------------------------------------------------------------------
// cutlass::bfloat16_t specializations: same widen/shuffle/narrow scheme.
// ---------------------------------------------------------------------------
template <>
TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val,
                                   int laneMask) {
  return bfloat16_t(__shfl_xor_sync(mask, static_cast<float>(val), laneMask));
}

template <>
TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
  return bfloat16_t(__shfl_down_sync(mask, static_cast<float>(val), delta));
}

template <>
TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
  return bfloat16_t(__shfl_up_sync(mask, static_cast<float>(val), delta));
}

template <>
TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
  return bfloat16_t(__shfl_sync(mask, static_cast<float>(val), srcLane));
}
} // namespace tl
......@@ -102,7 +102,7 @@ struct AllReduce {
__syncthreads();
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
}
if constexpr (offset == scale) {
return x;
......@@ -122,7 +122,7 @@ struct AllReduce {
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
}
if constexpr (offset == scale) {
return x;
......@@ -234,7 +234,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off);
T n = tl::shfl_down_sync(MASK, val, off);
if (lane < SEG - off)
val += n;
}
......@@ -244,10 +244,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W)
dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, (T)0);
T segSum = tl::shfl_sync(MASK, val, 0);
if (lane == 0)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, (T)0);
carry = tl::shfl_sync(MASK, carry, 0);
}
} else {
for (int seg = 0; seg * SEG < W; ++seg) {
......@@ -260,7 +260,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off);
T n = tl::shfl_up_sync(MASK, val, off);
if (lane >= off)
val += n;
}
......@@ -270,10 +270,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W)
dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
T segSum = tl::shfl_sync(MASK, val, SEG - 1);
if (lane == SEG - 1)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, SEG - 1);
carry = tl::shfl_sync(MASK, carry, SEG - 1);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment