Unverified Commit b10c64c8 authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[ROCm][Bugfix][Model] Fix illegal memory access when running qwen3_moe models...


[ROCm][Bugfix][Model] Fix illegal memory access when running qwen3_moe models with  rms_norm (Qwen3-235B-A22B,  Qwen3-30B-A3B, etc.) (#26192)
Signed-off-by: default avatarRandall Smith <ransmith@amd.com>
Signed-off-by: default avatarRandall Smith <Randall.Smith@amd.com>
Signed-off-by: default avatarrasmith <Randall.Smith@amd.com>
Co-authored-by: default avatarRandall Smith <ransmith@amd.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 0925b28a
...@@ -364,18 +364,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] ...@@ -364,18 +364,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(weight.is_contiguous()); TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int64_t input_stride = input.stride(-2); // We cannot just use `input.stride(-2)` if the tensor is not row-major.
// Instead, we use a 2d view to get the second-innermost stride.
// That way the dimensions (except the last one) can be arbitrarily permuted.
torch::Tensor input_view = input.view({-1, hidden_size});
int num_tokens = input_view.numel() / hidden_size;
int64_t input_stride = input_view.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { VLLM_DISPATCH_FLOATING_TYPES(
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( input_view.scalar_type(), "rms_norm_kernel", [&] {
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride, vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size); out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
}); input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
});
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment