Commit e89003dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-fth-fp8' into 'v0.9.2-dev'

nmz适配block和channel fp8

See merge request dcutoolkit/deeplearing/vllm!360
parents be18d0df db23fcac
......@@ -14,6 +14,7 @@ from vllm.utils import direct_register_custom_op
try:
from lmslim import quant_ops
from lmslim import quant_tools
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try:
......@@ -1692,66 +1693,67 @@ def scaled_fp4_experts_quant(
# fp8
# def scaled_fp8_quant(
# input: torch.Tensor,
# scale: Optional[torch.Tensor] = None,
# num_token_padding: Optional[int] = None,
# scale_ub: Optional[torch.Tensor] = None,
# use_per_token_if_dynamic: bool = False,
# output: Optional[torch.Tensor] = None,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8
# scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# # This code assumes batch_dim and num_tokens are flattened
# assert (input.ndim == 2)
# shape: Union[tuple[int, int], torch.Size] = input.shape
# # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# if output is None:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
# else:
# assert num_token_padding is None, \
# "padding not supported if output passed in"
# assert output.dtype == out_dtype
# if scale is None:
# if use_per_token_if_dynamic:
# scale = torch.empty((shape[0], 1),
# device=input.device,
# dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
# else:
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# assert scale.numel() == 1, f"{scale.shape}"
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
    output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Three paths exist:

    * ``scale`` given: static per-tensor quantization via
      ``torch.ops._C.static_scaled_fp8_quant``.
    * ``scale`` omitted, ``use_per_token_if_dynamic=False``: dynamic
      per-tensor quantization via ``torch.ops._C.dynamic_scaled_fp8_quant``.
    * ``scale`` omitted, ``use_per_token_if_dynamic=True``: dynamic
      per-token quantization delegated to lmslim's ``per_token_quant_fp8``.
      That helper allocates and returns its own tensors and is called with
      the input only, so ``scale_ub`` and ``num_token_padding`` are NOT
      applied on this path.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D
            (batch and token dimensions already flattened).
        scale: Optional scaling factor for the FP8 quantization. When
            given it must contain exactly one element.
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case. Currently ignored (see above).
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value. Only honoured on the
            per-tensor paths.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
        output: Optional pre-allocated destination of the platform FP8
            dtype. Cannot be combined with ``num_token_padding``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    # Platform-specific FP8 dtype (e.g. torch.float8_e4m3fnuz for ROCm on
    # MI300, per the upstream note).
    out_dtype: torch.dtype = current_platform.fp8_dtype()

    caller_supplied_output = output is not None
    if caller_supplied_output:
        assert num_token_padding is None, \
            "padding not supported if output passed in"
        assert output.dtype == out_dtype

    if scale is None and use_per_token_if_dynamic:
        # The helper returns freshly allocated tensors, so nothing is
        # pre-allocated here: buffers created up front would simply be
        # thrown away when the names are rebound.
        quantized, scale = per_token_quant_fp8(input.contiguous())
        # Keep the in-place contract of a caller-supplied buffer when the
        # helper's result is directly compatible with it; otherwise fall
        # back to returning the helper's own tensor.
        # NOTE(review): assumes the helper returns the quantized tensor with
        # the same shape as `input` -- confirm against lmslim.
        if (caller_supplied_output and quantized.shape == output.shape
                and quantized.dtype == output.dtype):
            output.copy_(quantized)
            return output, scale
        return quantized, scale

    if output is None:
        shape: Union[tuple[int, int], torch.Size] = input.shape
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        # Dynamic per-tensor: the kernel writes the computed scale into
        # this single-element tensor.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        assert scale.numel() == 1, f"{scale.shape}"
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    return output, scale
# gptq allspark
......
......@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe=False,
use_fused_gate: Optional[bool] = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
......
......@@ -140,7 +140,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None, **_,
) -> torch.Tensor:
return self.fp8_linear.apply(input=x,
weight=layer.weight,
......
......@@ -857,7 +857,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,**_,
) -> torch.Tensor:
if enable_eplb:
assert expert_load_view is not None
......
......@@ -11,7 +11,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from lmslim.layers.gemm.fp8_utils import triton_scaled_mm_fp8
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
......@@ -278,25 +278,27 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
qinput = qinput.view(-1,qinput.shape[-1])
output = triton_scaled_mm_fp8(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
# if type(output) is tuple and len(output) == 2:
# output = output[0]
# # Unpad (undo num_token_padding)
# output = torch.narrow(output, 0, 0, input_2d.shape[0])
# x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
#
# # DQ
# # C = sw * sx * (X * W) + bias
# output = output * x_scale * scale_b.t()
# if bias is not None:
# output = output + bias
return output.view(*output_shape)
def dispatch_w8a8_scaled_mm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment