Unverified commit 18bb216c authored by Chayenne, committed by GitHub

Revert "[MOE] enable efficient moe_alignment multi-blocks execution (3x~6x)" (#3982)

parent 6b859e7d
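For context, every hunk below touches the same block-alignment step for MoE top-k routing: tokens are grouped by the expert they route to, and each expert's segment is padded up to a multiple of block_size. A minimal, unoptimized PyTorch sketch of the quantities the kernels compute (function and variable names here are illustrative, not from the repo) is the shared contract that calculate_diff below compares the CUDA and Triton paths against:

import torch

def moe_align_reference(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    # Count tokens per expert, then pad each expert's segment to a multiple of block_size.
    flat = topk_ids.flatten()
    counts = torch.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    cumsum = torch.zeros(num_experts + 1, dtype=torch.int64)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    total = int(cumsum[-1])
    # Padding slots keep the sentinel value numel, matching sorted_ids.fill_(topk_ids.numel()).
    sorted_token_ids = torch.full((total,), flat.numel(), dtype=torch.int64)
    expert_ids = torch.empty(total // block_size, dtype=torch.int64)
    offsets = cumsum[:-1].clone()
    for token_idx, expert in enumerate(flat.tolist()):
        sorted_token_ids[offsets[expert]] = token_idx
        offsets[expert] += 1
    for e in range(num_experts):
        expert_ids[cumsum[e] // block_size : cumsum[e + 1] // block_size] = e
    return sorted_token_ids, expert_ids, total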
@@ -99,12 +99,13 @@ def moe_align_block_size_triton(
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
tokens_cnts: torch.Tensor,
cumsum: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts,)
tokens_cnts = torch.zeros(
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
)
cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
tokens_per_thread = ceil_div(numel, num_experts)
moe_align_block_size_stage1[grid](
@@ -138,18 +139,11 @@ def moe_align_block_size_triton(
)
def calculate_diff(batch_size, seq_len, num_experts):
num_experts = num_experts
def calculate_diff(batch_size, seq_len):
num_experts = 256
block_size = 128
topk = 8
assert batch_size >= 1
assert seq_len >= 1
assert num_experts >= 4
if topk > num_experts:
topk = num_experts
topk_ids = torch.stack(
[
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
@@ -181,13 +175,6 @@ def calculate_diff(batch_size, seq_len, num_experts):
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
token_cnts_buffer_triton = torch.zeros(
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
)
cumsum_buffer_triton = torch.zeros(
(num_experts + 1,), dtype=torch.int32, device=topk_ids.device
)
# compare the performance of cuda and triton implementation
moe_align_block_size(
topk_ids,
@@ -206,27 +193,14 @@ def calculate_diff(batch_size, seq_len, num_experts):
sorted_ids_triton,
expert_ids_triton,
num_tokens_post_pad_triton,
token_cnts_buffer_triton,
cumsum_buffer_triton,
)
sorted_ids_cuda_snapshot = sorted_ids_cuda[: cumsum_buffer[1]].sort().values
sorted_ids_triton_snapshot = sorted_ids_triton[: cumsum_buffer[1]].sort().values
if (
torch.allclose(expert_ids_cuda, expert_ids_triton)
and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton)
and torch.allclose(sorted_ids_cuda_snapshot, sorted_ids_triton_snapshot)
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
):
print(
"✅ CUDA and Triton implementations match : num_tokens={}, num_experts={}".format(
batch_size * seq_len, num_experts
)
)
print("✅ CUDA and Triton implementations match")
else:
print("❌ CUDA and Triton implementations do not match")
print("CUDA sorted ids:", sorted_ids_cuda_snapshot)
print("Triton sorted ids:", sorted_ids_triton_snapshot)
print("CUDA expert_ids:", expert_ids_cuda)
print("Triton expert_ids:", expert_ids_triton)
print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda)
@@ -282,7 +256,7 @@ def benchmark(batch_size, seq_len, provider):
)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.zeros(
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
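The buffer sizes in this setup follow from a worst-case bound: each expert can contribute at most block_size - 1 padding slots, and the true padded length is itself a multiple of block_size, so floor division cannot under-count the number of blocks. A standalone sketch of the same arithmetic (sizes are illustrative):

import torch

num_tokens, topk, num_experts, block_size = 4096, 8, 256, 128
topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32)

# Worst case: every expert's segment is padded by up to block_size - 1 slots.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
# The actual padded length is a multiple of block_size, so floor division suffices here.
max_num_m_blocks = max_num_tokens_padded // block_size

sorted_ids = torch.full((max_num_tokens_padded,), topk_ids.numel(), dtype=torch.int32)
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32)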
@@ -293,37 +267,34 @@ def benchmark(batch_size, seq_len, provider):
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
# Warm up
api_func = (
moe_align_block_size if provider == "cuda" else moe_align_block_size_triton
)
for _ in range(10):
api_func(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
token_cnts_buffer.clone(),
cumsum_buffer.clone(),
quantiles = [0.5, 0.2, 0.8]
if provider == "cuda":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
token_cnts_buffer,
cumsum_buffer,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
quantiles=quantiles,
)
torch.cuda.synchronize()
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: api_func(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
token_cnts_buffer.clone(),
cumsum_buffer.clone(),
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
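For reference on the timing call: with quantiles=[0.5, 0.2, 0.8], triton.testing.do_bench returns the median, 20th-, and 80th-percentile runtimes in milliseconds, and the return statement above scales them to microseconds. A self-contained sketch with a placeholder workload (not the kernel under test):

import torch
import triton.testing

x = torch.randn(4096, 4096, device="cuda")
quantiles = [0.5, 0.2, 0.8]

# do_bench times the lambda repeatedly and returns the requested runtime quantiles (ms).
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.relu(x), quantiles=quantiles)

# Multiply by 1000 to report microseconds, as the benchmark above does.
print(f"median={1000 * ms:.1f}us  p20={1000 * min_ms:.1f}us  p80={1000 * max_ms:.1f}us")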
@@ -335,22 +306,8 @@ if __name__ == "__main__":
default="./configs/benchmark_ops/moe_align_blocks/",
help="Path to save moe align benchmark results",
)
parser.add_argument(
"--verify",
action="store_true",
help="verify kernel",
)
args = parser.parse_args()
if args.verify:
num_experts_range = [2**i for i in range(3, 9)]
calculate_diff(batch_size=4, seq_len=1024)
configs = list(
itertools.product(batch_size_range, seq_length_range, num_experts_range)
)
for bs, seq, num_experts in configs:
calculate_diff(batch_size=bs, seq_len=seq, num_experts=num_experts)
else:
benchmark.run(print_data=True, save_path=args.save_path)
benchmark.run(print_data=True)
[build-system]
requires = ["setuptools>=61.0", "wheel", "torch<=2.5.1"]
requires = ["setuptools>=61.0", "wheel", "torch"]
build-backend = "setuptools.build_meta"
[project]
@@ -16,284 +16,77 @@ limitations under the License.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cooperative_groups.h>
#include <torch/extension.h>
#include <THC/THCAtomics.cuh>
#include "utils.h"
#define MAX_NUM_EXPERTS 256
#define EXPERTS_PER_WARP ((MAX_NUM_EXPERTS) / (WARP_SIZE))
#define WARP_SIZE 32
#define FRAGS_PER_BLOCK 4
#define FRAG_SIZE_M 16
#define FRAG_SIZE_N 16
#ifndef USE_ROCM
#define kWarpsToLoad 2
#else
#define kWarpsToLoad 1
#endif
#define kElementsPerAccess 4
#define kElementsPerThr 16
#define SGLANG_FORCE_INLINE_DEVICE_FUNC static __forceinline__ __attribute__((always_inline)) __device__
namespace cg = cooperative_groups;
SGLANG_FORCE_INLINE_DEVICE_FUNC void store_global_cumsum(int* cumsum /*dest*/, int* total_tokens_post_pad /*dest*/,
const int32_t* local_offsets, const int& tid,
const int& num_experts, cg::grid_group& grid) {
int active_threads = CEILDIV(num_experts + 1, kElementsPerThr);
if (tid < active_threads - 1) {
for (int i = tid * kElementsPerThr; i < (tid + 1) * kElementsPerThr; i += kElementsPerAccess) {
*(int4*)(cumsum + i) = *(int4*)(local_offsets + i);
}
}
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ cumsum_buffer, size_t numel) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
if (tid == active_threads - 1) {
#pragma unroll
for (int i = tid * kElementsPerThr; i < num_experts + 1; i++) {
*(cumsum + i) = *(local_offsets + i);
}
}
if (tid == active_threads) {
*total_tokens_post_pad = local_offsets[num_experts];
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
__threadfence_system();
grid.sync();
}
SGLANG_FORCE_INLINE_DEVICE_FUNC void align_global_cumsum(int32_t* local_offsets /*src_and_dest*/,
int32_t* local_offsets_buf, int* smem_ptr, const int tid,
const int32_t& block_size, const int32_t& num_experts) {
int active_threads = CEILDIV(num_experts, kElementsPerThr);
int start = tid * kElementsPerThr + 1;
int end = MIN((tid + 1) * kElementsPerThr, num_experts) + 1;
if (tid == 0) {
smem_ptr[0] = 0;
}
if (tid < active_threads) {
for (int i = start; i < end; ++i) {
smem_ptr[i] = local_offsets[i] - local_offsets[i - 1];
}
}
__syncthreads();
if (tid < active_threads) {
for (int i = start; i < end; ++i) {
int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1];
local_offsets[i] = last_val + CEILDIV(smem_ptr[i], block_size) * block_size;
}
local_offsets_buf[tid] = local_offsets[end - 1];
}
__syncthreads();
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
__shared__ int32_t shared_counts[WARP_SIZE][8];
if (tid < active_threads && tid > 0) {
int offset = 0;
for (int j = 0; j < tid; ++j) {
offset += local_offsets_buf[j];
}
const int warp_id = threadIdx.x / WARP_SIZE;
const int experts_per_warp = 8;
const int my_expert_start = warp_id * experts_per_warp;
for (int i = start; i < end; ++i) {
local_offsets[i] += offset;
for (int i = 0; i < experts_per_warp; ++i) {
if (my_expert_start + i < num_experts) {
shared_counts[warp_id][i] = 0;
}
}
__syncthreads();
}
SGLANG_FORCE_INLINE_DEVICE_FUNC void reduce_unaligned_cumsum(int* tokens_cnts_ptr /*src_and_dest*/, int* smem_ptr,
int32_t* local_offsets, const int& tid, const int& lane_id,
const int& warp_id, const int32_t& num_experts,
cg::grid_group& grid) {
int total_fragments = CEILDIV(num_experts, FRAG_SIZE_N);
int fragments_per_block = CEILDIV(total_fragments, gridDim.x);
int fragments_per_warp = CEILDIV(fragments_per_block, FRAGS_PER_BLOCK);
for (int i = 0; i < gridDim.x; i += FRAG_SIZE_M) {
for (int j = 0; j < fragments_per_warp; j++) {
if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) {
const int kNumThrPerRow = WARP_SIZE / FRAG_SIZE_N;
int sRow = lane_id / kNumThrPerRow;
int sWarpColStride = kNumThrPerRow * kElementsPerAccess;
int sWarpColOff = warp_id * sWarpColStride;
int sThrColOff = lane_id % kNumThrPerRow * kElementsPerAccess;
int sCol = sThrColOff + sWarpColOff;
int gRow = i + sRow;
int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N;
int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N;
int gWarpColOff_1 = warp_id % kWarpsToLoad * sWarpColStride;
int gCol = gBlockColOff + gWarpColOff_0 + gWarpColOff_1 + sThrColOff;
if (gRow < num_experts && gCol < num_experts) {
int4* tokens_cnts_4i_ptr = (int4*)(tokens_cnts_ptr + (gRow + 1) * num_experts + gCol);
int4* smem_4i_ptr = (int4*)(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol);
*smem_4i_ptr = *tokens_cnts_4i_ptr;
}
}
__syncthreads();
if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) {
if (warp_id % kWarpsToLoad == 0) {
for (int k = 0; k < FRAG_SIZE_M; k += (WARP_SIZE / FRAG_SIZE_N)) {
int sRow = lane_id / FRAG_SIZE_N + k;
int sThrColOff = lane_id % FRAG_SIZE_N;
int sCol = sThrColOff + (warp_id / kWarpsToLoad) * FRAG_SIZE_N;
int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N;
int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N;
int gCol = gBlockColOff + gWarpColOff_0 + sThrColOff;
if (gCol < num_experts) {
atomicAdd(local_offsets + gCol + 1, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol));
// atomicAdd(tokens_cnts_ptr + gCol, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol));
}
}
}
}
__syncthreads();
__syncthreads();
} // end of j
} // end of i
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
if (threadIdx.x < num_experts) {
atomicAdd(tokens_cnts_ptr + threadIdx.x, *(local_offsets + threadIdx.x + 1));
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int expert_id = topk_ids[i];
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
}
__threadfence_system();
grid.sync();
if (tid < num_experts) {
*(local_offsets + tid + 1) = *(tokens_cnts_ptr + tid);
}
__syncthreads();
}
SGLANG_FORCE_INLINE_DEVICE_FUNC void parallel_unaligned_local_cumsum(
const int& tid, int* tokens_cnts_ptr /*dest*/, int32_t* local_offsets /*dest*/, int32_t* local_offsets_buf,
const int32_t (*shared_counts)[EXPERTS_PER_WARP] /*src*/, const int& experts_per_warp, const int32_t& num_experts,
cg::grid_group& grid) {
int active_threads = CEILDIV(num_experts, kElementsPerThr);
if (threadIdx.x == 0) {
local_offsets[0] = 0;
}
if (threadIdx.x < active_threads) {
for (int i = threadIdx.x * kElementsPerThr + 1; i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1;
++i) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
int expert_count = 0;
int warp_idx = (i - 1) / experts_per_warp;
int expert_offset = (i - 1) % experts_per_warp;
expert_count = shared_counts[warp_idx][expert_offset];
int expert_count = shared_counts[warp_idx][expert_offset];
int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1];
local_offsets[i] = last_val + expert_count;
cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
}
local_offsets_buf[threadIdx.x] = local_offsets[MIN((threadIdx.x + 1) * kElementsPerThr, num_experts)];
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
if (threadIdx.x < active_threads && threadIdx.x > 0) {
int offset = 0;
for (int j = 0; j < threadIdx.x; ++j) {
offset += local_offsets_buf[j];
}
for (int i = threadIdx.x * kElementsPerThr + 1; i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1;
++i) {
local_offsets[i] += offset;
}
}
__syncthreads();
if (tid < num_experts) {
*(tokens_cnts_ptr + tid) = 0;
}
if (threadIdx.x < num_experts) {
*(tokens_cnts_ptr + (blockIdx.x + 1) * num_experts + threadIdx.x) = *(local_offsets + threadIdx.x + 1);
*(local_offsets + threadIdx.x + 1) = 0;
} else if (threadIdx.x < MAX_NUM_EXPERTS) {
*(local_offsets + threadIdx.x + 1) = 0;
}
__threadfence_system();
grid.sync();
}
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* __restrict__ tokens_cnts,
int32_t* __restrict__ cumsum, const int tokens_per_block,
const int tokens_per_thread, const int K) {
__shared__ int32_t smem[FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK];
int32_t(*shared_counts)[EXPERTS_PER_WARP] = (int32_t(*)[EXPERTS_PER_WARP]) & smem[0];
__shared__ int32_t local_offsets[MAX_NUM_EXPERTS + 1];
__shared__ int32_t local_offsets_buf[CEILDIV(MAX_NUM_EXPERTS, kElementsPerThr)];
const int tid = threadIdx.x + blockDim.x * blockIdx.x;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int experts_per_warp = EXPERTS_PER_WARP;
int* tokens_cnts_ptr = &(tokens_cnts[0]);
int* smem_ptr = &(smem[0]);
cg::grid_group grid = cg::this_grid();
if (threadIdx.x < FRAG_SIZE_M * FRAG_SIZE_N) {
for (int i = 0; i < FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK; i += FRAG_SIZE_M * FRAG_SIZE_N) {
smem[threadIdx.x + i] = 0;
}
}
__syncthreads();
const size_t start_idx = tokens_per_block * blockIdx.x + tokens_per_thread * threadIdx.x;
const size_t end_idx = start_idx + tokens_per_thread;
if (threadIdx.x * tokens_per_thread < tokens_per_block) {
for (int i = start_idx; i < MIN(numel, end_idx); ++i) {
int expert_id = topk_ids[i];
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
}
}
__syncthreads();
parallel_unaligned_local_cumsum(tid, tokens_cnts_ptr /*dest*/, local_offsets, local_offsets_buf, shared_counts,
experts_per_warp, num_experts, grid);
reduce_unaligned_cumsum(tokens_cnts_ptr /*src_and_dest*/, smem_ptr, local_offsets, tid, lane_id, warp_id, num_experts,
grid);
align_global_cumsum(local_offsets /*src_and_dest*/, local_offsets_buf, smem_ptr, tid, block_size, num_experts);
store_global_cumsum(cumsum /*dest*/, total_tokens_post_pad /*dest*/, local_offsets /*src*/, tid, num_experts, grid);
if (tid < num_experts) {
for (int i = local_offsets[tid]; i < local_offsets[tid + 1]; i += block_size) {
expert_ids[i / block_size] = tid;
}
}
__syncthreads();
if (threadIdx.x * tokens_per_thread < tokens_per_block) {
for (int i = start_idx; i < MIN(numel, end_idx); ++i) {
int32_t expert_id = topk_ids[i];
int32_t rank_post_pad = atomicAdd(&cumsum[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}
}
@@ -302,29 +95,22 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only supports DeepSeek V3 (num_experts == 256) for now.");
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel = moe_align_block_size_kernel<scalar_t>;
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
const int block_threads = 256;
const int num_blocks = MIN(CEILDIV(topk_ids.sizes()[0], block_threads), num_experts);
scalar_t* topk_ids_ptr = topk_ids.data_ptr<scalar_t>();
int32_t* sorted_token_ids_ptr = sorted_token_ids.data_ptr<int32_t>();
int32_t* experts_ids_ptr = experts_ids.data_ptr<int32_t>();
int32_t* num_tokens_post_pad_ptr = num_tokens_post_pad.data_ptr<int32_t>();
size_t num_tokens = topk_ids.numel();
int32_t* token_cnts_buffer_ptr = token_cnts_buffer.data_ptr<int32_t>();
int32_t* cumsum_buffer_ptr = cumsum_buffer.data_ptr<int32_t>();
int tokens_per_block = CEILDIV(topk_ids.sizes()[0], num_blocks) * topk_ids.sizes()[1];
int tokens_per_thread = CEILDIV(tokens_per_block, block_threads);
int K = topk_ids.sizes()[1];
void* kernelArgs[] = {&topk_ids_ptr, &sorted_token_ids_ptr, &experts_ids_ptr, &num_tokens_post_pad_ptr,
&num_experts, &block_size, &num_tokens, &token_cnts_buffer_ptr,
&cumsum_buffer_ptr, &tokens_per_block, &tokens_per_thread, &K};
cudaLaunchCooperativeKernel((void*)kernel, num_blocks, block_threads, kernelArgs);
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
const int max_blocks = 65535;
const int actual_blocks = std::min(num_blocks, max_blocks);
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
});
}
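A serial Python sketch of what the second (sort/scatter) kernel launched here does, assuming cumsum_buffer already holds each expert's block-aligned start offset; the atomicAdd in the kernel corresponds to the per-expert counter bump below (names are illustrative):

import torch

def count_and_sort_reference(topk_ids: torch.Tensor,
                             cumsum_buffer: torch.Tensor,
                             sorted_token_ids: torch.Tensor) -> None:
    # cumsum_buffer[e] is the next free slot in expert e's segment; the CUDA kernel
    # advances it with atomicAdd so many threads can scatter concurrently.
    for i, expert_id in enumerate(topk_ids.flatten().tolist()):
        pos = int(cumsum_buffer[expert_id])
        cumsum_buffer[expert_id] += 1
        sorted_token_ids[pos] = i

The launch configuration mirrors the same idea: ceil-divide the flattened top-k ids over 256-thread blocks, cap the grid at 65535, and let the grid-stride loop in the kernel cover any remainder.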
@@ -49,17 +49,6 @@ struct cuda_error : public std::runtime_error {
} \
} while (0)
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
template <typename T>
void check(T result, char const* const func, const char* const file, int const line) {
if (result) {
fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line, static_cast<unsigned int>(result),
cudaGetErrorString(result), func);
cudaDeviceReset();
exit(EXIT_FAILURE);
}
}
#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
@@ -106,22 +95,3 @@ inline int getSMVersion() {
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize // 64
#endif
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_cooperative_groups.h>
#include <hip/hip_runtime.h>
static __inline__ __host__ __device__ hipError_t cudaLaunchCooperativeKernel(const void* f, dim3 gridDim,
dim3 blockDimX, void** kernelParams) {
return hipLaunchCooperativeKernel(f, gridDim, blockDimX, kernelParams, 0, hipStreamDefault);
}
#endif
@@ -171,12 +171,12 @@ def test_moe_align_block_size_compare_implementations(block_size, num_tokens, to
num_tokens_post_pad_cuda = torch.empty(
(1), dtype=torch.int32, device=topk_ids.device
)
token_cnts_buffer = torch.zeros(
token_cnts_buffer = torch.empty(
(num_experts + 1) * num_experts,
dtype=torch.int32,
device=topk_ids.device,
)
cumsum_buffer = torch.zeros(
cumsum_buffer = torch.empty(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
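Whether these scratch buffers need zero-initialization depends on how a given implementation uses them: a buffer a kernel accumulates into (as the reverted multi-block kernel does with atomicAdd on its token-counts scratch) must start zeroed, while a buffer that is fully written before it is read can be torch.empty. A generic illustration of the distinction, with hypothetical buffer names rather than the ones in this test:

import torch

num_experts = 256

# A buffer a kernel accumulates into (e.g. with atomicAdd) must start zeroed,
# otherwise stale values leak into the counts.
accumulation_scratch = torch.zeros(
    (num_experts + 1) * num_experts, dtype=torch.int32, device="cuda"
)

# A buffer the kernel fully writes before reading (e.g. a computed prefix sum)
# can be torch.empty, which skips the memset.
write_before_read_scratch = torch.empty(num_experts + 1, dtype=torch.int32, device="cuda")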