Unverified commit 9b9e8253, authored by Qi Yuhang, committed by GitHub
Browse files

[Fix] Fix index out-of-bounds (OOB) in get_group_gemm_starts kernel. (#8564)

parent 66a398f4
...@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts( ...@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
int* problem_sizes, int* problem_sizes,
int* problem_sizes_transpose, int* problem_sizes_transpose,
bool transpose = false) { bool transpose = false) {
int expert_id = threadIdx.x; int64_t expert_id = static_cast<int64_t>(threadIdx.x);
if (expert_id >= gridDim.x * blockDim.x) { if (expert_id >= gridDim.x * blockDim.x) {
return; return;
...@@ -46,11 +46,11 @@ __global__ void get_group_gemm_starts( ...@@ -46,11 +46,11 @@ __global__ void get_group_gemm_starts(
problem_sizes_transpose[expert_id * 3 + 2] = k; problem_sizes_transpose[expert_id * 3 + 2] = k;
} }
int32_t expert_offset = expert_offsets[expert_id]; int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int a_stride = 0; int64_t a_stride = 0;
int b_stride = 0; int64_t b_stride = 0;
int a_scale_stride = 0; int64_t a_scale_stride = 0;
int b_scale_stride = 0; int64_t b_scale_stride = 0;
if (!transpose) { if (!transpose) {
a_stride = expert_offset * k; a_stride = expert_offset * k;
b_stride = expert_id * k * n; b_stride = expert_id * k * n;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment