Unverified Commit c60578de authored by Kevin McKay's avatar Kevin McKay Committed by GitHub
Browse files

[Bugfix][Hardware][AMD] Use dynamic WARP_SIZE in sampler vectorized_process (#31295)


Signed-off-by: default avatarc0de128 <kevin.mckay@outlook.com>
parent 80fead8b
#include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include <torch/cuda.h> #include <torch/cuda.h>
...@@ -97,7 +98,9 @@ static inline __device__ bool isPartialMatch(float x, uint32_t pattern) { ...@@ -97,7 +98,9 @@ static inline __device__ bool isPartialMatch(float x, uint32_t pattern) {
template <typename T, typename idxT, typename Func> template <typename T, typename idxT, typename Func>
__device__ void vectorized_process(size_t thread_rank, size_t num_threads, __device__ void vectorized_process(size_t thread_rank, size_t num_threads,
const T* in, idxT len, Func f) { const T* in, idxT len, Func f) {
constexpr int WARP_SIZE = 32; // Use dynamic WARP_SIZE from cuda_compat.h to support both
// Wave64 (MI300X/gfx942) and Wave32 (Strix Halo/gfx1151) architectures
constexpr int kWarpSize = WARP_SIZE;
using WideT = float4; using WideT = float4;
if constexpr (sizeof(T) >= sizeof(WideT)) { if constexpr (sizeof(T) >= sizeof(WideT)) {
for (idxT i = thread_rank; i < len; i += num_threads) { for (idxT i = thread_rank; i < len; i += num_threads) {
...@@ -132,8 +135,8 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, ...@@ -132,8 +135,8 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads,
} }
} }
static_assert(WARP_SIZE >= items_per_scalar); static_assert(kWarpSize >= items_per_scalar);
// and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt // and because items_per_scalar > skip_cnt, kWarpSize > skip_cnt
// no need to use loop // no need to use loop
if (thread_rank < skip_cnt) { if (thread_rank < skip_cnt) {
f(in[thread_rank], thread_rank); f(in[thread_rank], thread_rank);
...@@ -142,7 +145,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, ...@@ -142,7 +145,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads,
// len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
// and so // and so
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
// WARP_SIZE no need to use loop // kWarpSize no need to use loop
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
if (remain_i < len) { if (remain_i < len) {
f(in[remain_i], remain_i); f(in[remain_i], remain_i);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment