Unverified Commit a88b3be7 authored by ElizaWszola's avatar ElizaWszola Committed by GitHub
Browse files

[Bugfix] Fix quant RMS norm fusion for quantization with TMA-aligned scales (#33255)


Signed-off-by: default avatarElizaWszola <ewszola@redhat.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent a49ea5a5
...@@ -315,7 +315,9 @@ void silu_and_mul_scaled_fp4_experts_quant( ...@@ -315,7 +315,9 @@ void silu_and_mul_scaled_fp4_experts_quant(
void per_token_group_quant_fp8(const torch::Tensor& input, void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s, torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min, int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0); double fp8_max, bool scale_ue8m0,
bool dummy_is_scale_transposed,
bool dummy_is_tma_aligned);
void per_token_group_quant_int8(const torch::Tensor& input, void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_q,
......
...@@ -97,7 +97,7 @@ __global__ void rms_norm_per_block_quant_kernel( ...@@ -97,7 +97,7 @@ __global__ void rms_norm_per_block_quant_kernel(
scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size] scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size, float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) { scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
float rms; float rms;
// Compute RMS // Compute RMS
// Always able to vectorize due to constraints on hidden_size // Always able to vectorize due to constraints on hidden_size
...@@ -108,7 +108,8 @@ __global__ void rms_norm_per_block_quant_kernel( ...@@ -108,7 +108,8 @@ __global__ void rms_norm_per_block_quant_kernel(
// Always able to vectorize due to constraints on hidden_size and group_size // Always able to vectorize due to constraints on hidden_size and group_size
vllm::vectorized::compute_dynamic_per_token_scales< vllm::vectorized::compute_dynamic_per_token_scales<
scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual); nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual,
outer_scale_stride);
// RMS Norm + Quant // RMS Norm + Quant
// Always able to vectorize due to constraints on hidden_size // Always able to vectorize due to constraints on hidden_size
...@@ -119,7 +120,8 @@ __global__ void rms_norm_per_block_quant_kernel( ...@@ -119,7 +120,8 @@ __global__ void rms_norm_per_block_quant_kernel(
vllm::vectorized::norm_and_quant< vllm::vectorized::norm_and_quant<
scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>, scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
has_residual, is_scale_transposed, group_size>( has_residual, is_scale_transposed, group_size>(
out, input, weight, rms, scales, hidden_size, residual); out, input, weight, rms, scales, hidden_size, residual,
outer_scale_stride);
} }
} // namespace vllm } // namespace vllm
...@@ -225,7 +227,8 @@ void rms_norm_per_block_quant_dispatch( ...@@ -225,7 +227,8 @@ void rms_norm_per_block_quant_dispatch(
: nullptr, : nullptr,
var_epsilon, hidden_size, var_epsilon, hidden_size,
has_residual ? residual->data_ptr<scalar_in_t>() has_residual ? residual->data_ptr<scalar_in_t>()
: nullptr); : nullptr,
scales.stride(1));
}); });
}); });
}); });
...@@ -257,6 +260,11 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -257,6 +260,11 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
TORCH_CHECK(group_size == 128 || group_size == 64, TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size); "Unsupported group size: ", group_size);
if (scales.stride(1) > 1) {
TORCH_CHECK(is_scale_transposed,
"Outer scale stride must be 1 when scales are not transposed");
}
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size, rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
var_epsilon, scale_ub, residual, var_epsilon, scale_ub, residual,
is_scale_transposed); is_scale_transposed);
......
...@@ -74,7 +74,7 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -74,7 +74,7 @@ __device__ void compute_dynamic_per_token_scales(
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub, float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
int32_t const group_size = 0) { int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
float block_absmax_val_maybe = 0.0f; float block_absmax_val_maybe = 0.0f;
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>}; constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
__syncthreads(); __syncthreads();
...@@ -133,7 +133,9 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -133,7 +133,9 @@ __device__ void compute_dynamic_per_token_scales(
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val()); scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
// Global output store // Global output store
if constexpr (is_scale_transposed) { if constexpr (is_scale_transposed) {
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
blockIdx.x] = scale; blockIdx.x] = scale;
} else { } else {
all_token_scales[blockIdx.x * num_groups + all_token_scales[blockIdx.x * num_groups +
...@@ -180,13 +182,11 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -180,13 +182,11 @@ __device__ void compute_dynamic_per_token_scales(
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted, template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false, bool is_scale_transposed = false> bool has_residual = false, bool is_scale_transposed = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output, __device__ void norm_and_quant(
scalar_t const* __restrict__ input, scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight, scalar_t const* __restrict__ weight, float const rms, float* const scale,
float const rms, float* const scale, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr,
int32_t const hidden_size, int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
scalar_t* __restrict__ residual = nullptr,
int32_t const group_size = 0) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
...@@ -202,7 +202,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, ...@@ -202,7 +202,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
int64_t scale_idx = 0; int64_t scale_idx = 0;
if (group_size > 0) { if (group_size > 0) {
if constexpr (is_scale_transposed) { if constexpr (is_scale_transposed) {
scale_idx = (i / group_size) * gridDim.x + blockIdx.x; int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
scale_idx = (i / group_size) * scale_rows + blockIdx.x;
} else { } else {
scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size; scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size;
} }
...@@ -286,8 +288,8 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -286,8 +288,8 @@ __device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales, float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub, float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
scalar_t const* __restrict__ residual = nullptr) { int64_t outer_scale_stride = 1) {
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>}; constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
const int VEC_SIZE = 4; const int VEC_SIZE = 4;
...@@ -382,7 +384,9 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -382,7 +384,9 @@ __device__ void compute_dynamic_per_token_scales(
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val()); scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
// Global output store // Global output store
if constexpr (is_scale_transposed) { if constexpr (is_scale_transposed) {
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
blockIdx.x] = scale; blockIdx.x] = scale;
} else { } else {
all_token_scales[blockIdx.x * num_groups + all_token_scales[blockIdx.x * num_groups +
...@@ -463,7 +467,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, ...@@ -463,7 +467,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
scalar_t const* __restrict__ weight, scalar_t const* __restrict__ weight,
float const rms, float* const scale, float const rms, float* const scale,
int32_t const hidden_size, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) { scalar_t* __restrict__ residual = nullptr,
int64_t outer_scale_stride = 1) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/output/weight/residual to better utilize memory bandwidth. // Vectorized input/output/weight/residual to better utilize memory bandwidth.
...@@ -516,7 +521,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, ...@@ -516,7 +521,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
int64_t const num_groups = hidden_size / group_size; int64_t const num_groups = hidden_size / group_size;
int64_t scale_idx = 0; int64_t scale_idx = 0;
if constexpr (is_scale_transposed) { if constexpr (is_scale_transposed) {
scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x; int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
scale_idx = (i * VEC_SIZE / group_size) * scale_rows + blockIdx.x;
} else { } else {
scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size; scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size;
} }
......
...@@ -379,7 +379,9 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input, ...@@ -379,7 +379,9 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
void per_token_group_quant_fp8(const torch::Tensor& input, void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s, torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min, int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0) { double fp8_max, bool scale_ue8m0,
bool dummy_is_scale_transposed = false,
bool dummy_is_tma_aligned = false) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0); fp8_min, fp8_max, scale_ue8m0);
} }
\ No newline at end of file
...@@ -643,11 +643,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -643,11 +643,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#ifndef USE_ROCM #ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor. // Compute per-token-group FP8 quantized tensor and scaling factor.
// The dummy arguments are here so we can correctly fuse with RMSNorm.
ops.def( ops.def(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! " "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, " "output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool " "int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0) -> ()"); "scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
") -> ()");
ops.impl("per_token_group_fp8_quant", torch::kCUDA, ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8); &per_token_group_quant_fp8);
......
...@@ -50,10 +50,9 @@ def test_tp1_fp8_fusions( ...@@ -50,10 +50,9 @@ def test_tp1_fp8_fusions(
run_e2e_fusion_test, run_e2e_fusion_test,
monkeypatch, monkeypatch,
): ):
if use_deepgemm: if use_deepgemm and is_blackwell():
# TODO(luka/eliza) DeepGEMM uses different quants, matching not supported # TODO(luka) DeepGEMM uses different quants, matching not supported
# - on Blackwell, uses a special quant fp8, currently not supported # - on Blackwell, uses a special quant fp8, currently not supported
# - on Hopper, tma-aligned scales inhibit matching (fix WIP)
pytest.skip("DeepGEMM & quant matching not currently supported") pytest.skip("DeepGEMM & quant matching not currently supported")
matches = matches_fn(n_layers) matches = matches_fn(n_layers)
...@@ -66,7 +65,6 @@ def test_tp1_fp8_fusions( ...@@ -66,7 +65,6 @@ def test_tp1_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","), custom_ops=custom_ops.split(","),
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest import pytest
import torch import torch
...@@ -21,7 +23,7 @@ QUANT_DTYPES = [torch.int8, current_platform.fp8_dtype()] ...@@ -21,7 +23,7 @@ QUANT_DTYPES = [torch.int8, current_platform.fp8_dtype()]
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029] VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
# Avoid combinatorial explosion with full Cartesian product # Avoid combinatorial explosion with full Cartesian product
NUM_TOKENS_HIDDEN_SIZES = [ NUM_TOKENS_HIDDEN_SIZES = [
*[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]], *[(1, i) for i in [1, 64, 128, *VEC_HIDDEN_SIZES, 5120, 5137]],
*[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]], *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]],
*[(4096, i) for i in [1, 64, 5137]], *[(4096, i) for i in [1, 64, 5137]],
] ]
...@@ -29,6 +31,7 @@ NUM_TOKENS_HIDDEN_SIZES = [ ...@@ -29,6 +31,7 @@ NUM_TOKENS_HIDDEN_SIZES = [
ADD_RESIDUAL = [False, True] ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False] SCALE_UBS = [True, False]
GROUP_SIZES = [None, [1, 64], [1, 128]] GROUP_SIZES = [None, [1, 64], [1, 128]]
TMA_ALIGNMENTS = [0, 4]
SEEDS = [0] SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
...@@ -110,12 +113,21 @@ def ops_dynamic_per_token_or_block_quant( ...@@ -110,12 +113,21 @@ def ops_dynamic_per_token_or_block_quant(
residual: torch.Tensor | None, residual: torch.Tensor | None,
scale_ub: torch.Tensor | None, scale_ub: torch.Tensor | None,
group_size: list[int] | None, group_size: list[int] | None,
tma_alignment: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if residual is not None: if residual is not None:
residual = residual.clone() residual = residual.clone()
if group_size is not None: if group_size is not None:
out, scales = ops.rms_norm_per_block_quant( out, scales = ops.rms_norm_per_block_quant(
x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True x,
weight,
EPS,
quant_dtype,
group_size,
scale_ub,
residual,
True,
tma_alignment,
) )
scales = scales.contiguous() scales = scales.contiguous()
else: else:
...@@ -132,9 +144,10 @@ def ops_impl( ...@@ -132,9 +144,10 @@ def ops_impl(
residual: torch.Tensor | None, residual: torch.Tensor | None,
scale_ub: torch.Tensor | None, scale_ub: torch.Tensor | None,
group_size: list[int] | None, group_size: list[int] | None,
tma_alignment: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ops_dynamic_per_token_or_block_quant( return ops_dynamic_per_token_or_block_quant(
weight, x, quant_dtype, residual, scale_ub, group_size weight, x, quant_dtype, residual, scale_ub, group_size, tma_alignment
) )
...@@ -143,7 +156,10 @@ def ops_impl( ...@@ -143,7 +156,10 @@ def ops_impl(
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS) @pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("group_size", GROUP_SIZES) @pytest.mark.parametrize(
"group_size, tma_alignment",
[(None, 0), *itertools.product(GROUP_SIZES, TMA_ALIGNMENTS)],
)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
...@@ -156,6 +172,7 @@ def test_rms_norm( ...@@ -156,6 +172,7 @@ def test_rms_norm(
dtype: torch.dtype, dtype: torch.dtype,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_size: list[int] | None, group_size: list[int] | None,
tma_alignment: int,
seed: int, seed: int,
device: str, device: str,
) -> None: ) -> None:
...@@ -173,6 +190,20 @@ def test_rms_norm( ...@@ -173,6 +190,20 @@ def test_rms_norm(
# blockwise baseline doesn't support scale_ub # blockwise baseline doesn't support scale_ub
return return
if (
group_size is None or quant_dtype != current_platform.fp8_dtype()
) and tma_alignment != 0:
# TMA alignment is only supported for groupwise fp8 kernels
return
if (
group_size is not None
and tma_alignment != 0
and hidden_size // group_size[1] % tma_alignment == 0
):
# Skip tests where TMA alignment doesn't create extra padding to save time
return
if has_scale_ub and quant_dtype != current_platform.fp8_dtype(): if has_scale_ub and quant_dtype != current_platform.fp8_dtype():
# skip # skip
return return
...@@ -196,7 +227,7 @@ def test_rms_norm( ...@@ -196,7 +227,7 @@ def test_rms_norm(
layer, x, quant_dtype, residual, scale_ub, group_size layer, x, quant_dtype, residual, scale_ub, group_size
) )
ops_out, ops_scales, ops_residual = ops_impl( ops_out, ops_scales, ops_residual = ops_impl(
layer.weight, x, quant_dtype, residual, scale_ub, group_size layer.weight, x, quant_dtype, residual, scale_ub, group_size, tma_alignment
) )
assert ref_out.dtype == quant_dtype assert ref_out.dtype == quant_dtype
......
...@@ -450,15 +450,30 @@ def rms_norm_per_block_quant( ...@@ -450,15 +450,30 @@ def rms_norm_per_block_quant(
scale_ub: torch.Tensor | None = None, scale_ub: torch.Tensor | None = None,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
is_scale_transposed: bool = False, is_scale_transposed: bool = False,
tma_alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert len(group_size) == 2 assert len(group_size) == 2
output = torch.empty_like(input, dtype=quant_dtype) output = torch.empty_like(input, dtype=quant_dtype)
if is_scale_transposed: if is_scale_transposed:
if tma_alignment == 0:
scales = torch.empty( scales = torch.empty(
(input.shape[-1] // group_size[1], input.numel() // input.shape[-1]), (input.shape[-1] // group_size[1], input.numel() // input.shape[-1]),
device=input.device, device=input.device,
dtype=torch.float32, dtype=torch.float32,
).transpose(0, 1) ).transpose(0, 1)
else:
m = input.shape[-2]
sf_k = input.shape[-1] // group_size[1]
tma_aligned_m = (m + tma_alignment - 1) // tma_alignment * tma_alignment
shape = input.shape[:-2] + (m, sf_k)
stride = (
(1, tma_aligned_m)
if input.dim() == 2
else (tma_aligned_m * sf_k, 1, tma_aligned_m)
)
scales = torch.empty_strided(
shape, stride, device=input.device, dtype=torch.float32
)
else: else:
scales = torch.empty( scales = torch.empty(
(input.numel() // input.shape[-1], input.shape[-1] // group_size[1]), (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]),
...@@ -466,6 +481,10 @@ def rms_norm_per_block_quant( ...@@ -466,6 +481,10 @@ def rms_norm_per_block_quant(
dtype=torch.float32, dtype=torch.float32,
) )
assert tma_alignment in [0, 4], "Expected TMA alignment 0 or 4, but got " + str(
tma_alignment
)
torch.ops._C.rms_norm_per_block_quant( torch.ops._C.rms_norm_per_block_quant(
output, output,
input, input,
......
...@@ -292,6 +292,7 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -292,6 +292,7 @@ class MatcherQuantFP8(MatcherCustomOp):
has_col_major_scales: bool = False, has_col_major_scales: bool = False,
is_e8m0: bool = False, is_e8m0: bool = False,
match_rocm_aiter: bool = False, match_rocm_aiter: bool = False,
is_tma_aligned: bool = False,
) -> None: ) -> None:
if enabled is None: if enabled is None:
enabled = QuantFP8.enabled() enabled = QuantFP8.enabled()
...@@ -301,6 +302,7 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -301,6 +302,7 @@ class MatcherQuantFP8(MatcherCustomOp):
self.has_col_major_scales = has_col_major_scales self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0 self.is_e8m0 = is_e8m0
self.match_rocm_aiter = match_rocm_aiter self.match_rocm_aiter = match_rocm_aiter
self.is_tma_aligned = is_tma_aligned
if match_rocm_aiter: if match_rocm_aiter:
assert not quant_key.scale.group_shape.is_per_tensor(), ( assert not quant_key.scale.group_shape.is_per_tensor(), (
...@@ -336,6 +338,7 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -336,6 +338,7 @@ class MatcherQuantFP8(MatcherCustomOp):
quant_key.scale.group_shape, quant_key.scale.group_shape,
column_major_scales=has_col_major_scales, column_major_scales=has_col_major_scales,
use_ue8m0=is_e8m0, use_ue8m0=is_e8m0,
tma_aligned_scales=self.is_tma_aligned,
compile_native=False, compile_native=False,
) )
...@@ -367,6 +370,9 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -367,6 +370,9 @@ class MatcherQuantFP8(MatcherCustomOp):
) )
if self.quant_key.scale.group_shape.is_per_group(): if self.quant_key.scale.group_shape.is_per_group():
# for tma_aligned, the scale must be passed to forward_custom
# tma_aligned fusion then matches by custom op arguments
if not self.is_tma_aligned:
assert scale is None assert scale is None
scale = self.make_scale(input, transposed=self.has_col_major_scales) scale = self.make_scale(input, transposed=self.has_col_major_scales)
...@@ -384,6 +390,8 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -384,6 +390,8 @@ class MatcherQuantFP8(MatcherCustomOp):
fp8_min=fp8_min, fp8_min=fp8_min,
fp8_max=fp8_max, fp8_max=fp8_max,
scale_ue8m0=self.is_e8m0, scale_ue8m0=self.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
) )
return result, scale return result, scale
......
...@@ -121,6 +121,7 @@ class RMSNormQuantPattern: ...@@ -121,6 +121,7 @@ class RMSNormQuantPattern:
key: FusedRMSQuantKey, key: FusedRMSQuantKey,
has_col_major_scales: bool = False, has_col_major_scales: bool = False,
is_e8m0: bool = False, is_e8m0: bool = False,
is_tma_aligned: bool = False,
) -> None: ) -> None:
self.epsilon = epsilon self.epsilon = epsilon
self.quant_dtype = key.quant.dtype self.quant_dtype = key.quant.dtype
...@@ -136,7 +137,10 @@ class RMSNormQuantPattern: ...@@ -136,7 +137,10 @@ class RMSNormQuantPattern:
else MatcherFusedAddRMSNorm(epsilon) else MatcherFusedAddRMSNorm(epsilon)
) )
self.quant_matcher = MatcherQuantFP8( self.quant_matcher = MatcherQuantFP8(
key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 key.quant,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
) )
...@@ -262,8 +266,9 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -262,8 +266,9 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape, group_shape: GroupShape,
symmetric: bool = True, symmetric: bool = True,
has_col_major_scales: bool = False,
is_e8m0: bool = False, is_e8m0: bool = False,
has_col_major_scales: bool = True,
is_tma_aligned: bool = True,
) -> None: ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
...@@ -271,29 +276,63 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -271,29 +276,63 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
) )
self.group_shape = group_shape self.group_shape = group_shape
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0 self.is_e8m0 = is_e8m0
self.has_col_major_scales = has_col_major_scales
self.is_tma_aligned = is_tma_aligned
super().__init__( super().__init__(
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 epsilon,
key,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
) )
def register(self, pm_pass: PatternMatcherPass) -> None: def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms) result = torch.empty(
result_rms.shape,
device=result_rms.device,
dtype=self.quant_matcher.quant_key.dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
input=result_rms,
output_q=result,
output_s=scale,
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.quant_matcher.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, residual, scale return result, residual, scale
def replacement( def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe. # optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype) result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
result=result, result=result,
...@@ -310,10 +349,12 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -310,10 +349,12 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
# result, residual, scale # result, residual, scale
return at[1], at[3], at[2] return at[1], at[3], at[2]
scale = self.quant_matcher.empty_f32(1, 1)
pm.register_replacement( pm.register_replacement(
pattern, pattern,
replacement, replacement,
self.rmsnorm_matcher.inputs(), self.rmsnorm_matcher.inputs() + [scale],
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
) )
...@@ -326,8 +367,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -326,8 +367,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape, group_shape: GroupShape,
symmetric: bool = True, symmetric: bool = True,
has_col_major_scales: bool = False,
is_e8m0: bool = False, is_e8m0: bool = False,
has_col_major_scales: bool = True,
is_tma_aligned: bool = True,
) -> None: ) -> None:
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
...@@ -335,29 +377,55 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -335,29 +377,55 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
) )
self.group_shape = group_shape self.group_shape = group_shape
self.has_col_major_scales = has_col_major_scales
self.is_tma_aligned = is_tma_aligned
super().__init__( super().__init__(
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 epsilon,
key,
has_col_major_scales=self.has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
) )
def register(self, pm_pass: PatternMatcherPass) -> None: def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, weight: torch.Tensor input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms) result = torch.empty(
result_rms.shape,
device=result_rms.device,
dtype=self.quant_matcher.quant_key.dtype,
)
assert scale is not None
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
fp8_min = finfo.min
fp8_max = finfo.max
_, result, scale = auto_functionalized(
self.quant_matcher.QUANT_OP,
input=result_rms,
output_q=result,
output_s=scale,
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.quant_matcher.is_e8m0,
dummy_is_scale_transposed=self.has_col_major_scales,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, scale return result, scale
def replacement( def replacement(
input: torch.Tensor, weight: torch.Tensor input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe. # optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype) result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.has_col_major_scales
)
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
result=result, result=result,
...@@ -368,16 +436,18 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -368,16 +436,18 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub=None, scale_ub=None,
residual=None, residual=None,
group_size=self.group_shape[1], group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.has_col_major_scales, is_scale_transposed=self.has_col_major_scales,
) )
# result, scale # result, scale
return at[1], at[2] return at[1], at[2]
scale = self.quant_matcher.empty_f32(1, 1)
pm.register_replacement( pm.register_replacement(
pattern, pattern,
replacement, replacement,
self.rmsnorm_matcher.inputs(), self.rmsnorm_matcher.inputs() + [scale],
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
) )
...@@ -532,13 +602,15 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -532,13 +602,15 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]: for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
for has_col_major_scales in [True, False]: for has_col_major_scales in [True, False]:
for is_e8m0 in [True, False]: for is_e8m0 in [True, False]:
for is_tma_aligned in [False, True]:
# Fuse fused_add_rms_norm + fp8 group quant # Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern( FusedAddRMSNormGroupQuantPattern(
epsilon, epsilon,
FP8_DTYPE, FP8_DTYPE,
group_shape=group_shape, group_shape=group_shape,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0, is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns) ).register(self.patterns)
# Fuse rms_norm + fp8 group quant # Fuse rms_norm + fp8 group quant
...@@ -546,8 +618,9 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -546,8 +618,9 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
epsilon, epsilon,
FP8_DTYPE, FP8_DTYPE,
group_shape=group_shape, group_shape=group_shape,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0, is_e8m0=is_e8m0,
has_col_major_scales=has_col_major_scales,
is_tma_aligned=is_tma_aligned,
).register(self.patterns) ).register(self.patterns)
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
......
...@@ -924,7 +924,16 @@ def per_token_group_quant_fp8( ...@@ -924,7 +924,16 @@ def per_token_group_quant_fp8(
# TODO(bnell): this causes some fp8 moe test to fail. # TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous(): if current_platform.is_cuda() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant( torch.ops._C.per_token_group_fp8_quant(
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0 x,
x_q,
x_s,
group_size,
eps,
fp8_min,
fp8_max,
use_ue8m0,
column_major_scales,
tma_aligned_scales,
) )
return x_q, x_s return x_q, x_s
......
...@@ -349,7 +349,7 @@ def _align(x: int, y: int) -> int: ...@@ -349,7 +349,7 @@ def _align(x: int, y: int) -> int:
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19 # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
def get_tma_aligned_size(x: int, element_size: int): def get_tma_aligned_size(x: int, element_size: int) -> int:
return _align(x, 16 // element_size) return _align(x, 16 // element_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment