Unverified Commit 54d4bd62 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Support 16-bit shfl_sync (#1169)

* Add type-safe warp shuffle helpers for 16-bit float types in common.h

- Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
- Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
- Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.

* lint fix
parent 7a80b6df
...@@ -379,3 +379,91 @@ namespace cutlass { ...@@ -379,3 +379,91 @@ namespace cutlass {
TL_DEVICE TL_DEVICE
bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); } bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); }
} // namespace cutlass } // namespace cutlass
//
// Type-safe warp shuffle helpers for 16-bit float types
// These wrappers avoid relying on implicit conversions that may be disallowed
// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to
// float for the shuffle and then down-converting.
//
namespace tl {
// Generic passthroughs: defer straight to the CUDA warp-shuffle
// intrinsics for every element type they natively support.
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
  return __shfl_xor_sync(mask, val, laneMask);
}
template <typename T>
TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
  return __shfl_down_sync(mask, val, delta);
}
template <typename T>
TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
  return __shfl_up_sync(mask, val, delta);
}
template <typename T> TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
  return __shfl_sync(mask, val, srcLane);
}
// Specializations for cutlass::half_t: round-trip through float (the
// half -> float conversion is exact) so the intrinsic operates on a
// natively shuffleable type, then convert back.
template <>
TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
  return half_t(__shfl_xor_sync(mask, static_cast<float>(val), laneMask));
}
template <>
TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
  return half_t(__shfl_down_sync(mask, static_cast<float>(val), delta));
}
template <>
TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
  return half_t(__shfl_up_sync(mask, static_cast<float>(val), delta));
}
template <> TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
  return half_t(__shfl_sync(mask, static_cast<float>(val), srcLane));
}
// Specializations for cutlass::bfloat16_t: same exact float round-trip.
template <>
TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val,
                                   int laneMask) {
  return bfloat16_t(__shfl_xor_sync(mask, static_cast<float>(val), laneMask));
}
template <>
TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
  return bfloat16_t(__shfl_down_sync(mask, static_cast<float>(val), delta));
}
template <>
TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
  return bfloat16_t(__shfl_up_sync(mask, static_cast<float>(val), delta));
}
template <>
TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
  return bfloat16_t(__shfl_sync(mask, static_cast<float>(val), srcLane));
}
} // namespace tl
...@@ -102,7 +102,7 @@ struct AllReduce { ...@@ -102,7 +102,7 @@ struct AllReduce {
__syncthreads(); __syncthreads();
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else { } else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset))); x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
} }
if constexpr (offset == scale) { if constexpr (offset == scale) {
return x; return x;
...@@ -122,7 +122,7 @@ struct AllReduce { ...@@ -122,7 +122,7 @@ struct AllReduce {
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads)); asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else { } else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset))); x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
} }
if constexpr (offset == scale) { if constexpr (offset == scale) {
return x; return x;
...@@ -234,7 +234,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D { ...@@ -234,7 +234,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll #pragma unroll
for (int off = 1; off < SEG; off <<= 1) { for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off); T n = tl::shfl_down_sync(MASK, val, off);
if (lane < SEG - off) if (lane < SEG - off)
val += n; val += n;
} }
...@@ -244,10 +244,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D { ...@@ -244,10 +244,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W) if (real_col < W)
dst[real_row * W + real_col] = val; dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, (T)0); T segSum = tl::shfl_sync(MASK, val, 0);
if (lane == 0) if (lane == 0)
carry = segSum; carry = segSum;
carry = (T)__shfl_sync(MASK, carry, (T)0); carry = tl::shfl_sync(MASK, carry, 0);
} }
} else { } else {
for (int seg = 0; seg * SEG < W; ++seg) { for (int seg = 0; seg * SEG < W; ++seg) {
...@@ -260,7 +260,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D { ...@@ -260,7 +260,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll #pragma unroll
for (int off = 1; off < SEG; off <<= 1) { for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off); T n = tl::shfl_up_sync(MASK, val, off);
if (lane >= off) if (lane >= off)
val += n; val += n;
} }
...@@ -270,10 +270,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D { ...@@ -270,10 +270,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W) if (real_col < W)
dst[real_row * W + real_col] = val; dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, SEG - 1); T segSum = tl::shfl_sync(MASK, val, SEG - 1);
if (lane == SEG - 1) if (lane == SEG - 1)
carry = segSum; carry = segSum;
carry = (T)__shfl_sync(MASK, carry, SEG - 1); carry = tl::shfl_sync(MASK, carry, SEG - 1);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment