"vllm/vscode:/vscode.git/clone" did not exist on "9659bc7f271ec640da780b5ca739e261764b954b"
Unverified Commit b5f8c309 authored by Maral's avatar Maral Committed by GitHub
Browse files

[W8A8 Block Linear Refactor][1/N] Keep all quantization types into `QuantFP8` class. (#33047)


Signed-off-by: default avatarmaral <maralbahari.98@gmail.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 21997f45
...@@ -1570,7 +1570,7 @@ class rocm_aiter_ops: ...@@ -1570,7 +1570,7 @@ class rocm_aiter_ops:
def group_fp8_quant( def group_fp8_quant(
input_2d: torch.Tensor, input_2d: torch.Tensor,
group_size: int = 128, group_size: int = 128,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert group_size == 128, "Group size must be 128" assert group_size == 128, "Group size must be 128"
return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size) return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)
......
...@@ -14,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -14,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
prep_scale_for_group_broadcast, prep_scale_for_group_broadcast,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
_FP8_DTYPE = current_platform.fp8_dtype() _FP8_DTYPE = current_platform.fp8_dtype()
_FP8_MIN, _FP8_MAX = get_fp8_min_max() _FP8_MIN, _FP8_MAX = get_fp8_min_max()
...@@ -59,7 +64,8 @@ class QuantFP8(CustomOp): ...@@ -59,7 +64,8 @@ class QuantFP8(CustomOp):
self.num_token_padding = num_token_padding self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales self.column_major_scales = column_major_scales
self.tma_aligned_scales = tma_aligned_scales self.tma_aligned_scales = tma_aligned_scales
self.use_ue8m0 = use_ue8m0 self.use_ue8m0 = is_deep_gemm_e8m0_used() if use_ue8m0 is None else use_ue8m0
self.use_deep_gemm_supported = is_deep_gemm_supported()
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled() self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
...@@ -79,10 +85,23 @@ class QuantFP8(CustomOp): ...@@ -79,10 +85,23 @@ class QuantFP8(CustomOp):
x: torch.Tensor, x: torch.Tensor,
scale: torch.Tensor | None = None, scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None, scale_ub: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.quantization.utils import fp8_utils
if (
self.is_group_quant
and self.use_deep_gemm_supported
and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
):
return fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
x,
group_size=self.group_size,
use_ue8m0=True,
)
if self.is_group_quant and not self.static: if self.is_group_quant and not self.static:
assert scale is None, "Dynamic group quantization does not use scale" assert scale is None, "Dynamic group quantization does not use scale"
from vllm.model_executor.layers.quantization.utils import fp8_utils
return fp8_utils.per_token_group_quant_fp8( return fp8_utils.per_token_group_quant_fp8(
x, x,
...@@ -116,25 +135,34 @@ class QuantFP8(CustomOp): ...@@ -116,25 +135,34 @@ class QuantFP8(CustomOp):
x: torch.Tensor, x: torch.Tensor,
scale: torch.Tensor | None = None, scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None, scale_ub: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
use_aiter_quant = ( use_triton = kwargs.get("use_triton", False)
not self.is_group_quant if self.is_group_quant and use_triton:
and self.use_aiter assert scale is None, "Dynamic group quantization does not use scale"
and scale_ub is None
and x.is_contiguous() return torch.ops.vllm.triton_per_token_group_quant_fp8(x, self.group_size)
)
use_aiter_quant = self.use_aiter and scale_ub is None and x.is_contiguous()
use_aiter_per_tensor_quant = ( use_aiter_per_tensor_quant = (
use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR use_aiter_quant and self.group_shape.is_per_tensor()
)
use_aiter_per_token_quant = (
use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
) )
use_aiter_per_token_quant = use_aiter_quant and self.group_shape.is_per_token()
use_aiter_per_group_quant = use_aiter_quant and self.group_shape.is_per_group()
if use_aiter_per_group_quant:
return rocm_aiter_ops.group_fp8_quant(x, self.group_size)
if use_aiter_per_tensor_quant: if use_aiter_per_tensor_quant:
return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale) return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
if use_aiter_per_token_quant: if use_aiter_per_token_quant:
return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale) return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
# Fallback to native implementation for group quantization.
if self.is_group_quant:
assert scale is None, "Dynamic group quantization does not use scale"
return self._quantize_group_native(x)
# Fallback to CUDA implementation # Fallback to CUDA implementation
return self.forward_cuda(x, scale, scale_ub) return self.forward_cuda(x, scale, scale_ub)
......
...@@ -33,7 +33,6 @@ from vllm.model_executor.utils import replace_parameter, set_weight_attrs ...@@ -33,7 +33,6 @@ from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt, fp8_gemm_nt,
get_tma_aligned_size, get_tma_aligned_size,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
...@@ -426,13 +425,6 @@ class W8A8BlockFp8LinearOp: ...@@ -426,13 +425,6 @@ class W8A8BlockFp8LinearOp:
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
use_ue8m0=True,
)
else:
assert self.deepgemm_input_quant_op is not None assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d) q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
output = torch.empty( output = torch.empty(
...@@ -497,15 +489,8 @@ class W8A8BlockFp8LinearOp: ...@@ -497,15 +489,8 @@ class W8A8BlockFp8LinearOp:
if input_scale is not None: if input_scale is not None:
q_input = input_2d q_input = input_2d
elif use_triton:
q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
input_2d,
self.act_quant_group_shape.col,
)
else: else:
q_input, input_scale = rocm_aiter_ops.group_fp8_quant( q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton)
input_2d, self.act_quant_group_shape.col
)
return gemm_a8w8_blockscale_op( return gemm_a8w8_blockscale_op(
q_input, q_input,
...@@ -572,7 +557,7 @@ class W8A8BlockFp8LinearOp: ...@@ -572,7 +557,7 @@ class W8A8BlockFp8LinearOp:
], ],
torch.Tensor, torch.Tensor,
], ],
QuantFP8 | None, QuantFP8,
]: ]:
if use_cutlass: if use_cutlass:
return self._run_cutlass, ( return self._run_cutlass, (
...@@ -584,7 +569,12 @@ class W8A8BlockFp8LinearOp: ...@@ -584,7 +569,12 @@ class W8A8BlockFp8LinearOp:
) )
) )
if use_aiter_and_is_supported: if use_aiter_and_is_supported:
return self._run_aiter, None return self._run_aiter, QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False,
)
return self._run_triton, ( return self._run_triton, (
QuantFP8( QuantFP8(
False, False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment