Unverified commit 9556af87, authored by Luka Govedič and committed by GitHub
Browse files

[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
parent a1a3523a
...@@ -101,8 +101,8 @@ steps: ...@@ -101,8 +101,8 @@ steps:
- nvidia-smi - nvidia-smi
# Run all models and attn backends but only Inductor partition and native custom ops # Run all models and attn backends but only Inductor partition and native custom ops
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and not +quant_fp8" - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and not +quant_fp8"
# Qwen requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported # Qwen/Deepseek requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and +quant_fp8 and qwen3" - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and not +rms_norm and +quant_fp8 and (qwen3 or deepseek)"
- label: Fusion E2E Config Sweep (H100) - label: Fusion E2E Config Sweep (H100)
timeout_in_minutes: 30 timeout_in_minutes: 30
...@@ -132,9 +132,9 @@ steps: ...@@ -132,9 +132,9 @@ steps:
commands: commands:
- nvidia-smi - nvidia-smi
# Run all models but only FLASHINFER, Inductor partition and native custom ops # Run all models but only FLASHINFER, Inductor partition and native custom ops
# Qwen requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported # Qwen/Deepseek requires +quant_fp8 as -quant_fp8 rms+quant fusion is not supported
# Run just llama3 (fp8 & fp4) for all config combinations (only inductor partition) # Run just llama3 (fp8 & fp4) for all config combinations (only inductor partition)
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and (FLASHINFER and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3) or llama-3)" - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "inductor_partition and (FLASHINFER and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek)) or llama-3)"
- label: Fusion E2E TP2 Quick (H100) - label: Fusion E2E TP2 Quick (H100)
timeout_in_minutes: 20 timeout_in_minutes: 20
...@@ -150,8 +150,8 @@ steps: ...@@ -150,8 +150,8 @@ steps:
commands: commands:
- nvidia-smi - nvidia-smi
# Run all models and attn backends but only Inductor partition and native custom ops # Run all models and attn backends but only Inductor partition and native custom ops
- pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "inductor_partition and not +rms_norm and not +quant_fp8" - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))"
- pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "inductor_partition and not +rms_norm and not +quant_fp8" - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))"
- label: Fusion E2E TP2 AR-RMS Config Sweep (H100) - label: Fusion E2E TP2 AR-RMS Config Sweep (H100)
timeout_in_minutes: 40 timeout_in_minutes: 40
...@@ -205,7 +205,7 @@ steps: ...@@ -205,7 +205,7 @@ steps:
commands: commands:
- nvidia-smi - nvidia-smi
# Run all models but only FLASHINFER, Inductor partition and native custom ops # Run all models but only FLASHINFER, Inductor partition and native custom ops
# include qwen with +quant_fp8 as -quant_fp8 rms+quant fusion is not supported # include qwen/deepseek with +quant_fp8 as -quant_fp8 rms+quant fusion is not supported
# for ar-rms-quant-fp4, also sweep llama3 # for ar-rms-quant-fp4, also sweep llama3
- pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "(FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3)) or Llama-3.1-8B-Instruct-FP4" - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py -k "(FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))) or Llama-3.1-8B-Instruct-FP4"
- pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and qwen3)" - pytest -v -s tests/compile/fusions_e2e/test_tp2_async_tp.py -k "FLASHINFER and inductor_partition and not +rms_norm and (not +quant_fp8 or +quant_fp8 and (qwen3 or deepseek))"
...@@ -15,31 +15,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( ...@@ -15,31 +15,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size] scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size, float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) { int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) {
float rms = 0.0f; float rms = 0.0f;
float token_scale = 0.0f; float token_scale = 0.0f;
// Compute rms // Compute rms
vllm::vectorized::compute_rms<scalar_t, has_residual>( vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual); &rms, input, hidden_size, input_stride, var_epsilon, residual);
// Compute scale // Compute scale
vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
has_residual>( has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, hidden_size, &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
residual); input_stride, residual);
// RMS Norm + Quant // RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) { if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale; token_scale = 1.0f / token_scale;
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true, vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
has_residual>( has_residual>(out, input, weight, rms,
out, input, weight, rms, &token_scale, hidden_size, residual); &token_scale, hidden_size,
input_stride, residual);
} else { } else {
// FP8 - Do not invert token_scale for exact match with FBGemm // FP8 - Do not invert token_scale for exact match with FBGemm
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false, vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
has_residual>( has_residual>(out, input, weight, rms,
out, input, weight, rms, &token_scale, hidden_size, residual); &token_scale, hidden_size,
input_stride, residual);
} }
} }
...@@ -51,38 +53,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( ...@@ -51,38 +53,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size] scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size, float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) { int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) {
// For vectorization, token_input and token_output pointers need to be // For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively. // aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0; bool const can_vectorize = hidden_size % 4 == 0 and input_stride % 4 == 0;
if (can_vectorize) { if (can_vectorize) {
return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t, return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
has_residual>( has_residual>(
out, scales, input, weight, scale_ub, var_epsilon, hidden_size, out, scales, input, weight, scale_ub, var_epsilon, hidden_size,
residual); input_stride, residual);
} }
float rms = 0.0f; float rms = 0.0f;
float token_scale = 0.0f; float token_scale = 0.0f;
// Compute RMS // Compute RMS
vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size, vllm::compute_rms<scalar_t, has_residual>(
var_epsilon, residual); &rms, input, hidden_size, input_stride, var_epsilon, residual);
// Compute Scale // Compute Scale
vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>( vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, hidden_size, &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
residual); input_stride, residual);
// RMS Norm + Quant // RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) { if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale; token_scale = 1.0f / token_scale;
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>( vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
out, input, weight, rms, &token_scale, hidden_size, residual); out, input, weight, rms, &token_scale, hidden_size, input_stride,
residual);
} else { } else {
// FP8 - Do not invert token_scale for exact match with FBGemm // FP8 - Do not invert token_scale for exact match with FBGemm
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>( vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
out, input, weight, rms, &token_scale, hidden_size, residual); out, input, weight, rms, &token_scale, hidden_size, input_stride,
residual);
} }
} }
...@@ -97,19 +101,20 @@ __global__ void rms_norm_per_block_quant_kernel( ...@@ -97,19 +101,20 @@ __global__ void rms_norm_per_block_quant_kernel(
scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size] scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size, float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) { int32_t const input_stride, scalar_t* __restrict__ residual = nullptr,
int64_t outer_scale_stride = 1) {
float rms; float rms;
// Compute RMS // Compute RMS
// Always able to vectorize due to constraints on hidden_size // Always able to vectorize due to constraints on hidden_size
vllm::vectorized::compute_rms<scalar_t, has_residual>( vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual); &rms, input, hidden_size, input_stride, var_epsilon, residual);
// Compute Scale // Compute Scale
// Always able to vectorize due to constraints on hidden_size and group_size // Always able to vectorize due to constraints on hidden_size and group_size
vllm::vectorized::compute_dynamic_per_token_scales< vllm::vectorized::compute_dynamic_per_token_scales<
scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual, nullptr, scales, input, weight, rms, scale_ub, hidden_size, input_stride,
outer_scale_stride); residual, outer_scale_stride);
// RMS Norm + Quant // RMS Norm + Quant
// Always able to vectorize due to constraints on hidden_size // Always able to vectorize due to constraints on hidden_size
...@@ -120,7 +125,7 @@ __global__ void rms_norm_per_block_quant_kernel( ...@@ -120,7 +125,7 @@ __global__ void rms_norm_per_block_quant_kernel(
vllm::vectorized::norm_and_quant< vllm::vectorized::norm_and_quant<
scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>, scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
has_residual, is_scale_transposed, group_size>( has_residual, is_scale_transposed, group_size>(
out, input, weight, rms, scales, hidden_size, residual, out, input, weight, rms, scales, hidden_size, input_stride, residual,
outer_scale_stride); outer_scale_stride);
} }
...@@ -137,6 +142,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( ...@@ -137,6 +142,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
std::optional<at::Tensor> const& scale_ub, std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual) { std::optional<at::Tensor>& residual) {
int32_t hidden_size = input.size(-1); int32_t hidden_size = input.size(-1);
int32_t input_stride = input.view({-1, hidden_size}).stride(0);
auto num_tokens = input.numel() / hidden_size; auto num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
...@@ -153,7 +159,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( ...@@ -153,7 +159,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(), out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(), input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr, scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size, var_epsilon, hidden_size, input_stride,
has_residual ? residual->data_ptr<scalar_in_t>() : nullptr); has_residual ? residual->data_ptr<scalar_in_t>() : nullptr);
}); });
}); });
...@@ -170,7 +176,9 @@ void rms_norm_dynamic_per_token_quant( ...@@ -170,7 +176,9 @@ void rms_norm_dynamic_per_token_quant(
? c10::ScalarType::Float8_e4m3fn ? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz; : c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1,
"Input must be contiguous in the last dimension");
if (scale_ub.has_value()) { if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type); TORCH_CHECK(out.dtype() == kFp8Type);
...@@ -179,6 +187,7 @@ void rms_norm_dynamic_per_token_quant( ...@@ -179,6 +187,7 @@ void rms_norm_dynamic_per_token_quant(
TORCH_CHECK(scales.dtype() == torch::kFloat32); TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) { if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type()); TORCH_CHECK(residual->scalar_type() == input.scalar_type());
TORCH_CHECK(residual->is_contiguous());
} }
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
...@@ -200,6 +209,15 @@ void rms_norm_per_block_quant_dispatch( ...@@ -200,6 +209,15 @@ void rms_norm_per_block_quant_dispatch(
std::optional<at::Tensor> const& scale_ub, std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual, bool is_scale_transposed) { std::optional<at::Tensor>& residual, bool is_scale_transposed) {
int32_t hidden_size = input.size(-1); int32_t hidden_size = input.size(-1);
int32_t input_stride = input.view({-1, hidden_size}).stride(0);
TORCH_CHECK(hidden_size % 4 == 0,
"Hidden size must be divisible by 4 for vectorized access");
TORCH_CHECK(input_stride % 4 == 0,
"Input stride must be divisible by 4 for vectorized access");
TORCH_CHECK(group_size % 4 == 0,
"Group size must be divisible by 4 for vectorized access");
auto num_tokens = input.numel() / hidden_size; auto num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
...@@ -225,7 +243,7 @@ void rms_norm_per_block_quant_dispatch( ...@@ -225,7 +243,7 @@ void rms_norm_per_block_quant_dispatch(
weight.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr, : nullptr,
var_epsilon, hidden_size, var_epsilon, hidden_size, input_stride,
has_residual ? residual->data_ptr<scalar_in_t>() has_residual ? residual->data_ptr<scalar_in_t>()
: nullptr, : nullptr,
scales.stride(1)); scales.stride(1));
...@@ -246,7 +264,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -246,7 +264,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
? c10::ScalarType::Float8_e4m3fn ? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz; : c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1,
"Input must be contiguous in the last dimension");
if (scale_ub.has_value()) { if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type); TORCH_CHECK(out.dtype() == kFp8Type);
...@@ -255,6 +275,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -255,6 +275,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
TORCH_CHECK(scales.dtype() == torch::kFloat32); TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) { if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type()); TORCH_CHECK(residual->scalar_type() == input.scalar_type());
TORCH_CHECK(residual->is_contiguous());
} }
TORCH_CHECK(group_size == 128 || group_size == 64, TORCH_CHECK(group_size == 128 || group_size == 64,
......
...@@ -16,14 +16,17 @@ namespace vllm { ...@@ -16,14 +16,17 @@ namespace vllm {
// has_residual must be true, if residual is not a nullptr // has_residual must be true, if residual is not a nullptr
template <typename scalar_t, bool has_residual = false> template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
int32_t const hidden_size, float const epsilon, int32_t const hidden_size,
int32_t const input_stride, float const epsilon,
scalar_t const* __restrict__ residual = nullptr) { scalar_t const* __restrict__ residual = nullptr) {
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// sum of squares // sum of squares
float ss = 0.0f; float ss = 0.0f;
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]); float x = static_cast<float>(input[input_token_offset + i]);
if constexpr (has_residual) { if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]); x += static_cast<float>(residual[token_offset + i]);
} }
...@@ -73,15 +76,20 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -73,15 +76,20 @@ __device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales, float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub, float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, int32_t const hidden_size, int32_t const input_stride,
scalar_t const* __restrict__ residual = nullptr,
int32_t const group_size = 0, int64_t outer_scale_stride = 1) { int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
float block_absmax_val_maybe = 0.0f; float block_absmax_val_maybe = 0.0f;
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>}; constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
__syncthreads(); __syncthreads();
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
if (group_size > 0) { if (group_size > 0) {
__shared__ float s_max_vals[1024];
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
int64_t num_groups = hidden_size / group_size; int64_t num_groups = hidden_size / group_size;
__shared__ float s_max_vals[1024];
int64_t const threads_per_group = blockDim.x / num_groups; int64_t const threads_per_group = blockDim.x / num_groups;
int64_t const thread_in_group = threadIdx.x % threads_per_group; int64_t const thread_in_group = threadIdx.x % threads_per_group;
int64_t const group_offset = threadIdx.x / threads_per_group * group_size; int64_t const group_offset = threadIdx.x / threads_per_group * group_size;
...@@ -89,7 +97,7 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -89,7 +97,7 @@ __device__ void compute_dynamic_per_token_scales(
int64_t const thread_end = int64_t const thread_end =
min(group_offset + group_size, static_cast<int64_t>(hidden_size)); min(group_offset + group_size, static_cast<int64_t>(hidden_size));
for (auto i = thread_offset; i < thread_end; i += threads_per_group) { for (auto i = thread_offset; i < thread_end; i += threads_per_group) {
float x = static_cast<float>(input[token_offset + i]); float x = static_cast<float>(input[input_token_offset + i]);
if constexpr (has_residual) { if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]); x += static_cast<float>(residual[token_offset + i]);
} }
...@@ -144,10 +152,8 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -144,10 +152,8 @@ __device__ void compute_dynamic_per_token_scales(
} }
__syncthreads(); __syncthreads();
} else { } else {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]); float x = static_cast<float>(input[input_token_offset + i]);
if constexpr (has_residual) { if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]); x += static_cast<float>(residual[token_offset + i]);
} }
...@@ -185,12 +191,15 @@ template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted, ...@@ -185,12 +191,15 @@ template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
__device__ void norm_and_quant( __device__ void norm_and_quant(
scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight, float const rms, float* const scale, scalar_t const* __restrict__ weight, float const rms, float* const scale,
int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int32_t const hidden_size, int32_t const input_stride,
int32_t const group_size = 0, int64_t outer_scale_stride = 1) { scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0,
int64_t outer_scale_stride = 1) {
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]); float x = static_cast<float>(input[input_token_offset + i]);
if constexpr (has_residual) { if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]); x += static_cast<float>(residual[token_offset + i]);
residual[token_offset + i] = static_cast<scalar_t>(x); residual[token_offset + i] = static_cast<scalar_t>(x);
...@@ -224,13 +233,16 @@ namespace vectorized { ...@@ -224,13 +233,16 @@ namespace vectorized {
// hidden_size must be a multiple of 4 // hidden_size must be a multiple of 4
template <typename scalar_t, bool has_residual = false> template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
int32_t const hidden_size, float const epsilon, int32_t const hidden_size,
int32_t const input_stride, float const epsilon,
scalar_t const* __restrict__ residual = nullptr) { scalar_t const* __restrict__ residual = nullptr) {
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/output to better utilize memory bandwidth. // Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input = vec4_t<scalar_t> const* vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]); reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
vec4_t<scalar_t> const* vec_residual = nullptr; vec4_t<scalar_t> const* vec_residual = nullptr;
if constexpr (has_residual) { if constexpr (has_residual) {
vec_residual = vec_residual =
...@@ -288,7 +300,8 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -288,7 +300,8 @@ __device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales, float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub, float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, int32_t const hidden_size, int32_t const input_stride,
scalar_t const* __restrict__ residual = nullptr,
int64_t outer_scale_stride = 1) { int64_t outer_scale_stride = 1) {
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>}; constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
...@@ -300,10 +313,13 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -300,10 +313,13 @@ __device__ void compute_dynamic_per_token_scales(
vec4_t<scalar_t> const* vec_weight = nullptr; vec4_t<scalar_t> const* vec_weight = nullptr;
vec4_t<scalar_t> const* vec_residual = nullptr; vec4_t<scalar_t> const* vec_residual = nullptr;
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
if constexpr (group_size > 0) { if constexpr (group_size > 0) {
__shared__ float s_max_vals[1024]; __shared__ float s_max_vals[1024];
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
int64_t const num_groups = hidden_size / group_size; int64_t const num_groups = hidden_size / group_size;
int64_t const threads_per_group = blockDim.x / num_groups; int64_t const threads_per_group = blockDim.x / num_groups;
int64_t const thread_in_group = threadIdx.x % threads_per_group; int64_t const thread_in_group = threadIdx.x % threads_per_group;
...@@ -312,7 +328,8 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -312,7 +328,8 @@ __device__ void compute_dynamic_per_token_scales(
int64_t const thread_offset = group_offset + thread_in_group; int64_t const thread_offset = group_offset + thread_in_group;
int64_t const thread_end = min(group_offset + (group_size >> 2), int64_t const thread_end = min(group_offset + (group_size >> 2),
static_cast<int64_t>(hidden_size >> 2)); static_cast<int64_t>(hidden_size >> 2));
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]); vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight); vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
if constexpr (has_residual) { if constexpr (has_residual) {
vec_residual = vec_residual =
...@@ -396,8 +413,8 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -396,8 +413,8 @@ __device__ void compute_dynamic_per_token_scales(
__syncthreads(); __syncthreads();
} else { } else {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); vec_input =
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]); reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight); vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
if constexpr (has_residual) { if constexpr (has_residual) {
vec_residual = vec_residual =
...@@ -462,18 +479,18 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -462,18 +479,18 @@ __device__ void compute_dynamic_per_token_scales(
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted, template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false, bool is_scale_transposed = false, bool has_residual = false, bool is_scale_transposed = false,
int32_t group_size = 0> int32_t group_size = 0>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output, __device__ void norm_and_quant(
scalar_t const* __restrict__ input, scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight, scalar_t const* __restrict__ weight, float const rms, float* const scale,
float const rms, float* const scale, int32_t const hidden_size, int32_t const input_stride,
int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
scalar_t* __restrict__ residual = nullptr, int64_t const input_token_offset =
int64_t outer_scale_stride = 1) { blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/output/weight/residual to better utilize memory bandwidth. // Vectorized input/output/weight/residual to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input = vec4_t<scalar_t> const* vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]); reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
vec4_t<scalar_t> const* vec_weight = vec4_t<scalar_t> const* vec_weight =
reinterpret_cast<vec4_t<scalar_t> const*>(weight); reinterpret_cast<vec4_t<scalar_t> const*>(weight);
q8x4_t<scalar_out_t>* vec_output = q8x4_t<scalar_out_t>* vec_output =
......
...@@ -72,6 +72,16 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -72,6 +72,16 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
rocm_aiter_ops.refresh_env_variables() rocm_aiter_ops.refresh_env_variables()
# Filter here to reduce code duplication
requires_mla = "deepseek" in model_name.lower()
is_mla = "mla" in attn_backend.backend.name.lower()
if requires_mla != is_mla:
pytest.skip(
f"Incompatible model '{model_name}' and "
f"attention backend '{attn_backend.backend.name}'"
)
# Disable compile cache to make sure custom passes run. # Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs. # Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
......
...@@ -44,6 +44,20 @@ ROCM_AITER_UNIFIED_ATTN = pytest.param( ...@@ -44,6 +44,20 @@ ROCM_AITER_UNIFIED_ATTN = pytest.param(
), ),
) )
# Parametrization entry selecting the FLASHINFER_MLA attention backend.
# The case is skipped unless the platform is Blackwell AND FlashInfer is
# available (the skipif condition fires if either check fails).
FLASHINFER_MLA_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.FLASHINFER_MLA),
id="FLASHINFER_MLA",
marks=pytest.mark.skipif(
not is_blackwell() or not has_flashinfer(),
reason="FI backend requires Blackwell and FlashInfer",
),
)
# Parametrization entry selecting the TRITON_MLA attention backend.
# No skip marks are attached here, so it is not gated on any
# platform/library check at this level.
TRITON_MLA_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.TRITON_MLA),
id="TRITON_MLA",
)
# Models # Models
llama3_8b = ModelFusionInfo( llama3_8b = ModelFusionInfo(
model_name="meta-llama/Llama-3.1-8B-Instruct", model_name="meta-llama/Llama-3.1-8B-Instruct",
...@@ -126,3 +140,25 @@ qwen3_a3b_fp8 = ModelFusionInfo( ...@@ -126,3 +140,25 @@ qwen3_a3b_fp8 = ModelFusionInfo(
async_tp=n_layers * 2, async_tp=n_layers * 2,
), ),
) )
# Expected fusion-match counts for deepseek-ai/DeepSeek-V3, expressed as a
# function of the number of layers (`n_layers`) the test runs with.
# NOTE(review): the "fp8" suffix suggests an FP8-quantized checkpoint/config;
# the quantization setup itself is not visible here - confirm against callers.
deepseek_v3_fp8 = ModelFusionInfo(
model_name="deepseek-ai/DeepSeek-V3",
matches=lambda n_layers: Matches(
# 3 per dense layer (first 3):
# - input_rms + qkv_proj
# - q_a_layernorm + q_b_proj (inside MLA wrapper)
# - post_attn_layernorm + MLP
# 2 per MoE layer (remaining) due to MoE wrapping
# i.e. 2 for every layer, plus 1 extra for each of the (up to 3) dense
# layers: 2 * n_layers + min(3, n_layers).
rms_quant_fusion=n_layers * 2 + min(3, n_layers),  # add for 3 dense layers
# TODO silu+block quant
# act_quant_fusion=min(3, n_layers),  # dense layers only
act_quant_fusion=0,
# MLA attn + quant not supported yet:
# https://github.com/vllm-project/vllm/issues/35792
attn_quant_fusion=0,
# presumably two norms per layer plus one final norm - TODO confirm
ar_rms_fusion=n_layers * 2 + 1,
# TODO
# sequence_parallel= n_layers * 2 + 1,
# async_tp=n_layers * 2,
),
)
...@@ -17,9 +17,12 @@ from .common import ( ...@@ -17,9 +17,12 @@ from .common import (
) )
from .models import ( from .models import (
FLASHINFER_ATTN, FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
ROCM_AITER_UNIFIED_ATTN, ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN, ROCM_ATTN,
TRITON_ATTN, TRITON_ATTN,
TRITON_MLA_ATTN,
deepseek_v3_fp8,
llama3_8b_fp4, llama3_8b_fp4,
llama3_8b_fp8, llama3_8b_fp8,
llama4_scout_fp4, llama4_scout_fp4,
...@@ -33,6 +36,9 @@ from .models import ( ...@@ -33,6 +36,9 @@ from .models import (
[ [
(*llama3_8b_fp8, False), (*llama3_8b_fp8, False),
(*qwen3_a3b_fp8, False), (*qwen3_a3b_fp8, False),
(*qwen3_a3b_fp8, True),
(*deepseek_v3_fp8, False),
(*deepseek_v3_fp8, True),
pytest.param( pytest.param(
*llama4_scout_fp8, *llama4_scout_fp8,
False, False,
...@@ -41,13 +47,6 @@ from .models import ( ...@@ -41,13 +47,6 @@ from .models import (
reason="Llama4 Scout FP8 only supported on CUDA", reason="Llama4 Scout FP8 only supported on CUDA",
), ),
), ),
pytest.param(
*qwen3_a3b_fp8,
True,
marks=pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA"
),
),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -57,6 +56,8 @@ from .models import ( ...@@ -57,6 +56,8 @@ from .models import (
FLASHINFER_ATTN, FLASHINFER_ATTN,
ROCM_ATTN, ROCM_ATTN,
ROCM_AITER_UNIFIED_ATTN, ROCM_AITER_UNIFIED_ATTN,
FLASHINFER_MLA_ATTN,
TRITON_MLA_ATTN,
], ],
) )
@pytest.mark.parametrize("n_layers", [6]) @pytest.mark.parametrize("n_layers", [6])
...@@ -75,6 +76,9 @@ def test_tp1_fp8_fusions( ...@@ -75,6 +76,9 @@ def test_tp1_fp8_fusions(
run_e2e_fusion_test, run_e2e_fusion_test,
monkeypatch, monkeypatch,
): ):
if use_deepgemm and not current_platform.is_cuda():
pytest.skip("DeepGemm only supported on CUDA")
if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported(): if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported():
# Flashinfer block FP8 GEMM has internal quantization, so it can't # Flashinfer block FP8 GEMM has internal quantization, so it can't
# be fused with other ops. # be fused with other ops.
...@@ -86,7 +90,8 @@ def test_tp1_fp8_fusions( ...@@ -86,7 +90,8 @@ def test_tp1_fp8_fusions(
matches = matches_fn(n_layers) matches = matches_fn(n_layers)
if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops: block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
if block_fp8 and "-quant_fp8" in custom_ops:
# This is why config forces +quant_fp8 by default # This is why config forces +quant_fp8 by default
pytest.skip("native QuantFP8 matching not supported for group quant") pytest.skip("native QuantFP8 matching not supported for group quant")
......
...@@ -17,7 +17,9 @@ from .common import ( ...@@ -17,7 +17,9 @@ from .common import (
) )
from .models import ( from .models import (
FLASHINFER_ATTN, FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
TRITON_ATTN, TRITON_ATTN,
deepseek_v3_fp8,
llama3_8b, llama3_8b,
llama3_8b_fp4, llama3_8b_fp4,
llama3_8b_fp8, llama3_8b_fp8,
...@@ -33,10 +35,12 @@ pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes ...@@ -33,10 +35,12 @@ pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides", "model_name, matches_fn, model_kwargs, hf_overrides",
# qwen3-fp8 should still fuse AR+rms even though group quant is not yet supported # qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported
[llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8], [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8],
)
@pytest.mark.parametrize(
"attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN]
) )
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4]) @pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
...@@ -54,7 +58,8 @@ def test_tp2_ar_rms_fp8_fusions( ...@@ -54,7 +58,8 @@ def test_tp2_ar_rms_fp8_fusions(
): ):
matches = matches_fn(n_layers) matches = matches_fn(n_layers)
if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops: block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
if block_fp8 and "-quant_fp8" in custom_ops:
# This is why config forces +quant_fp8 by default # This is why config forces +quant_fp8 by default
pytest.skip("native QuantFP8 matching not supported for group quant") pytest.skip("native QuantFP8 matching not supported for group quant")
......
...@@ -162,6 +162,7 @@ def ops_impl( ...@@ -162,6 +162,7 @@ def ops_impl(
) )
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
@torch.inference_mode() @torch.inference_mode()
def test_rms_norm( def test_rms_norm(
default_vllm_config, default_vllm_config,
...@@ -175,6 +176,7 @@ def test_rms_norm( ...@@ -175,6 +176,7 @@ def test_rms_norm(
tma_alignment: int, tma_alignment: int,
seed: int, seed: int,
device: str, device: str,
strided_input: bool,
) -> None: ) -> None:
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -184,17 +186,17 @@ def test_rms_norm( ...@@ -184,17 +186,17 @@ def test_rms_norm(
if group_size is not None and hidden_size % group_size[1] != 0: if group_size is not None and hidden_size % group_size[1] != 0:
# skip # skip
return pytest.skip("Skip non-divisible group sizes")
if group_size is not None and has_scale_ub: if group_size is not None and has_scale_ub:
# blockwise baseline doesn't support scale_ub # blockwise baseline doesn't support scale_ub
return pytest.skip("scale_ub not supported for blockwise/group quantization")
if ( if (
group_size is None or quant_dtype != current_platform.fp8_dtype() group_size is None or quant_dtype != current_platform.fp8_dtype()
) and tma_alignment != 0: ) and tma_alignment != 0:
# TMA alignment is only supported for groupwise fp8 kernels # TMA alignment is only supported for groupwise fp8 kernels
return pytest.skip("tma alignment not supported for per-token or int8 quantization")
if ( if (
group_size is not None group_size is not None
...@@ -202,21 +204,36 @@ def test_rms_norm( ...@@ -202,21 +204,36 @@ def test_rms_norm(
and hidden_size // group_size[1] % tma_alignment == 0 and hidden_size // group_size[1] % tma_alignment == 0
): ):
# Skip tests where TMA alignment doesn't create extra padding to save time # Skip tests where TMA alignment doesn't create extra padding to save time
return pytest.skip("Skip TMA alignment cases where no extra padding is added")
if has_scale_ub and quant_dtype != current_platform.fp8_dtype(): if has_scale_ub and quant_dtype != current_platform.fp8_dtype():
# skip # skip
return pytest.skip("scale_ub only supported for fp8 quantization")
layer = RMSNorm(hidden_size, EPS).to(dtype=dtype) layer = RMSNorm(hidden_size, EPS).to(dtype=dtype)
# Make weights # Make weights
layer.weight.data.normal_(mean=1.0, std=0.1) layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs # Make inputs: use a wider tensor and slice to create a non-contiguous
# (strided) input when strided_input=True. The last dimension stride
# remains 1, which the kernel requires.
scale = 1 / (hidden_size) scale = 1 / (hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale last_dim = 2 * hidden_size if strided_input else hidden_size
residual = torch.randn_like(x) * scale if add_residual else None x = torch.randn(num_tokens, last_dim, dtype=dtype) * scale
x = x[:, :hidden_size]
# dim 1 gets special-cased
x_is_strided = strided_input and num_tokens != 1
# check that the input is strided iff we expect it to be
assert x.is_contiguous() != x_is_strided
# Residual must still be contiguous
residual = (
torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
if add_residual
else None
)
if has_scale_ub: if has_scale_ub:
rms_x, _ = ref_rms_norm(layer, x, residual) rms_x, _ = ref_rms_norm(layer, x, residual)
scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda") scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
...@@ -260,12 +277,33 @@ def test_rms_norm( ...@@ -260,12 +277,33 @@ def test_rms_norm(
if add_residual: if add_residual:
assert torch.allclose(ref_residual, ops_residual) assert torch.allclose(ref_residual, ops_residual)
output = torch.empty_like(x, dtype=quant_dtype) output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
scales = torch.empty( scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
) )
opcheck( if group_size is None:
torch.ops._C.rms_norm_dynamic_per_token_quant, opcheck(
(output, x, layer.weight, scales, 1e-5, scale_ub, residual), torch.ops._C.rms_norm_dynamic_per_token_quant,
) (output, x, layer.weight, scales, 1e-5, scale_ub, residual),
)
else:
# TODO(luka/eliza) opcheck is broken?
# Somehow the cloned args are getting mutated in-place,
# which causes the opcheck to fail.
# https://github.com/vllm-project/vllm/issues/36688
return
opcheck(
torch.ops._C.rms_norm_per_block_quant,
(
output,
x,
layer.weight,
scales,
1e-5,
scale_ub,
residual,
group_size[1],
True, # is_scale_transposed
),
)
...@@ -427,7 +427,7 @@ def rms_norm_dynamic_per_token_quant( ...@@ -427,7 +427,7 @@ def rms_norm_dynamic_per_token_quant(
scale_ub: torch.Tensor | None = None, scale_ub: torch.Tensor | None = None,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=quant_dtype) output = torch.empty(input.shape, dtype=quant_dtype, device=input.device)
scales = torch.empty( scales = torch.empty(
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
) )
...@@ -451,7 +451,7 @@ def rms_norm_per_block_quant( ...@@ -451,7 +451,7 @@ def rms_norm_per_block_quant(
tma_alignment: int = 0, tma_alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert len(group_size) == 2 assert len(group_size) == 2
output = torch.empty_like(input, dtype=quant_dtype) output = torch.empty(input.shape, dtype=quant_dtype, device=input.device)
if is_scale_transposed: if is_scale_transposed:
if tma_alignment == 0: if tma_alignment == 0:
scales = torch.empty( scales = torch.empty(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment