"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "66a5279a9422962b1cff3ad0e5747e8903ae067b"
Unverified Commit dd865bef authored by Xiaoyu Zhang, committed by GitHub

[Hotfix] solve fp8 w8a8 ci test fail (#4531)

parent d373a48c
@@ -799,9 +799,18 @@ class Fp8MoEMethod:
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
+                    if _is_cuda:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    else:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = vllm_ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id]
+                        )
                     start += shard_size
             layer.w13_weight_scale = torch.nn.Parameter(
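Both branches above share one calling convention: the quantization op takes a dequantized weight shard plus the target (per-expert max) scale and returns a `(quantized_tensor, scale)` tuple, which is why the second element is discarded. A minimal sketch of that convention on the CUDA path; the shapes and scale value are hypothetical, and the import mirrors the guarded one this commit adds:

```python
import torch

from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant

# Hypothetical dequantized w13 shard and a shared per-expert max scale.
dq_weight = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")
max_scale = torch.tensor(0.05, dtype=torch.float32, device="cuda")

# Re-quantize the shard at the shared scale; the returned scale is dropped
# because the caller already tracks max_w13_scales.
qweight, _ = sgl_scaled_fp8_quant(dq_weight, max_scale)
```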
@@ -15,6 +15,13 @@ from sglang.srt.utils import (
     is_hip,
 )
 
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 _is_hip = is_hip()
@@ -27,13 +34,8 @@ if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
 
-    if use_vllm_cutlass_w8a8_fp8_kernel:
-        try:
-            from vllm import _custom_ops as ops
-
-            VLLM_AVAILABLE = True
-        except ImportError:
-            VLLM_AVAILABLE = False
+    if use_vllm_cutlass_w8a8_fp8_kernel and VLLM_AVAILABLE:
+        from vllm import _custom_ops as ops
     else:
         from sgl_kernel import fp8_scaled_mm
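The guard above separates "the vLLM kernel was requested" (the env flag) from "vLLM is importable" (VLLM_AVAILABLE), so a missing optional dependency degrades to the sgl_kernel path instead of failing at import time. A small self-contained sketch of the same pattern; the module probe and flag parsing here are illustrative, not sglang's exact helpers:

```python
import importlib.util
import os

# Probe the optional dependency without importing it eagerly.
VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None

# Opt-in flag, mirroring USE_VLLM_CUTLASS_W8A8_FP8_KERNEL.
use_vllm_kernel = os.environ.get(
    "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", "false"
).lower() in ("1", "true")

if use_vllm_kernel and VLLM_AVAILABLE:
    from vllm import _custom_ops as ops  # vLLM's CUTLASS w8a8 kernels
else:
    ops = None  # callers fall back to sgl_kernel's fp8_scaled_mm
```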
@@ -253,68 +255,69 @@ def apply_fp8_linear(
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
+    else:
         per_tensor_weights = weight_scale.numel() == 1
         per_tensor_activations = x_scale.numel() == 1
 
         if per_tensor_weights and per_tensor_activations:
             # Fused GEMM_DQ
             output = torch._scaled_mm(
                 qinput,
                 weight,
                 out_dtype=input.dtype,
                 scale_a=x_scale,
                 scale_b=weight_scale,
                 bias=bias,
             )
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
 
             return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
 
             # Symmetric quantized GEMM by definition computes the following:
             # C = (s_x * X) (s_w * W) + bias
             # This is equivalent to dequantizing the weights and activations
             # before applying a GEMM.
             #
             # In order to compute quantized operands, a quantized kernel
             # will rewrite the above like so:
             # C = s_w * s_x * (X * W) + bias
             #
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
 
             # Making sure the dummy tensor is on the same device as the weight
             global TORCH_DEVICE_IDENTITY
             if TORCH_DEVICE_IDENTITY.device != weight.device:
                 TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
 
             # GEMM
             # This computes C = (X * W).
             # Output in fp32 to allow subsequent ops to happen in-place
             output = torch._scaled_mm(
                 qinput,
                 weight,
                 scale_a=TORCH_DEVICE_IDENTITY,
                 scale_b=TORCH_DEVICE_IDENTITY,
                 out_dtype=torch.float32,
             )
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
             # Unpad (undo num_token_padding)
             output = torch.narrow(output, 0, 0, input_2d.shape[0])
             x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
 
             # DQ
             # C = sw * sx * (X * W) + bias
             output = output * x_scale * weight_scale.t()
             if bias is not None:
                 output = output + bias
             return output.to(dtype=input.dtype).view(*output_shape)
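The unfused fallback relies on the identity spelled out in the comments: rescaling after the GEMM matches dequantizing before it, i.e. (s_x X)(s_w W) + bias = s_x s_w (X W) + bias when s_x is per-token and s_w is per-channel. A tiny CPU check of that identity, with int8-style values standing in for fp8 and made-up shapes and scales:

```python
import torch

torch.manual_seed(0)
X_q = torch.randint(-127, 128, (4, 8)).float()   # "quantized" activations
W_q = torch.randint(-127, 128, (8, 16)).float()  # "quantized" weights
s_x = torch.rand(4, 1) * 0.01                    # per-token activation scales
s_w = torch.rand(1, 16) * 0.01                   # per-channel weight scales
bias = torch.randn(16)

dequant_then_gemm = (X_q * s_x) @ (W_q * s_w) + bias
gemm_then_rescale = (X_q @ W_q) * s_x * s_w + bias  # what the fallback computes
assert torch.allclose(dequant_then_gemm, gemm_then_rescale, atol=1e-4)
```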
@@ -6,7 +6,6 @@ import torch
 
 from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.utils import scalar_types
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.utils import is_cuda
@@ -133,11 +132,16 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
-    # (num_bits, is_sym) -> quant_type
-    TYPE_MAP = {
-        (4, True): scalar_types.uint4b8,
-        (8, True): scalar_types.uint8b128,
-    }
+    if VLLM_AVAILABLE:
+        from vllm.scalar_type import scalar_types
+
+        # (num_bits, is_sym) -> quant_type
+        TYPE_MAP = {
+            (4, True): scalar_types.uint4b8,
+            (8, True): scalar_types.uint8b128,
+        }
+    else:
+        raise ImportError("vllm is not installed")
 
     def __init__(
         self,
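One subtlety of the new guard in GPTQMarlinConfig: it runs while the class body is evaluated, so TYPE_MAP is only defined (and vllm.scalar_type only imported) when vLLM is present, and the ImportError otherwise fires as soon as the defining module is imported. A stripped-down illustration of a conditional class body; the names are made up, and it defaults instead of raising so the snippet stays runnable:

```python
FEATURE_AVAILABLE = False  # stand-in for VLLM_AVAILABLE

class Config:
    if FEATURE_AVAILABLE:
        TYPE_MAP = {"4bit": "uint4b8"}
    else:
        # Executed at class-definition time, i.e. when this module is imported.
        TYPE_MAP = {}

print(Config.TYPE_MAP)  # {}
```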
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/scalar_type.py
 
-import functools
-import struct
-from dataclasses import dataclass
-from enum import Enum
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 
 import torch
 
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
 
 def is_layer_skipped(
     prefix: str,
@@ -102,341 +106,12 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
-        weight[start:end, :], _ = ops.scaled_fp8_quant(weight_dq, max_w_scale)
+        if _is_cuda:
+            weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
+        else:
+            weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
+                weight_dq, max_w_scale
+            )
         start = end
 
     return max_w_scale, weight
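requantize_with_max_scale, shown truncated above, dequantizes each logical shard at its own scale and re-quantizes everything at the single largest scale, so the fused weight can carry one per-tensor scale. A rough sketch of that idea with plain tensors; the dequantize/quantize helpers and the clamp-instead-of-fp8-cast are simplifications, not the real kernels:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() * scale

def quantize(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (w / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)

shard_scales = [torch.tensor(0.01), torch.tensor(0.03)]
shards_q = [torch.randn(4, 8) * 100 for _ in shard_scales]  # pretend fp8 payloads
max_scale = torch.stack(shard_scales).max()

# Re-express every shard under the shared max scale, as the loop above does.
requantized = [
    quantize(dequantize(q, s), max_scale) for q, s in zip(shards_q, shard_scales)
]
```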
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
NONE = 0 # nans are not supported
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.:
`stored_value = value + bias`,
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
with that file.
"""
exponent: int
"""
Number of bits in the exponent if this is a floating point type
(zero if this is an integer type)
"""
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number of bits representing an integer, excluding the sign bit, if
this is an integer type.
"""
signed: bool
"If the type is signed (i.e. has a sign bit)"
bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0). For example, if we store the
type as an unsigned integer with a bias of 128, then the value 0 will be
stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
"""
_finite_values_only: bool = False
"""
Private: if infs are supported, use `has_infs()` instead.
"""
nan_repr: NanRepr = NanRepr.IEEE_754
"""
How NaNs are represented in this scalar type; returns a NanRepr value.
(not applicable for integer types)
"""
def _floating_point_max_int(self) -> int:
assert (
self.mantissa <= 52 and self.exponent <= 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_mantissa = (1 << self.mantissa) - 1
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
max_mantissa = max_mantissa - 1
max_exponent = (1 << self.exponent) - 2
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
assert (
self.exponent < 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_exponent = max_exponent + 1
# adjust the exponent to match that of a double
# for now we assume the exponent bias is the standard 2^(e-1) - 1 (where
# e is the number of exponent bits); there is some precedent for non-standard
# biases, for example `float8_e4m3b11fnuz` in
# https://github.com/jax-ml/ml_dtypes, but to avoid premature
# over-complication we just assume the standard exponent bias until
# there is a need to support non-standard biases
exponent_bias = (1 << (self.exponent - 1)) - 1
exponent_bias_double = (1 << 10) - 1 # double e = 11
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
# shift the mantissa and exponent into the proper positions for an
# IEEE double and bitwise-or them together.
return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
def _floating_point_max(self) -> float:
double_raw = self._floating_point_max_int()
return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
def _raw_max(self) -> Union[int, float]:
if self.is_floating_point():
return self._floating_point_max()
else:
assert (
self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
), "Cannot represent max as an int"
return (1 << self.mantissa) - 1
def _raw_min(self) -> Union[int, float]:
if self.is_floating_point():
assert (
self.is_signed()
), "We currently assume all floating point types are signed"
sign_bit_double = 1 << 63
max_raw = self._floating_point_max_int()
min_raw = max_raw | sign_bit_double
return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
else:
assert (
not self.is_signed() or self.size_bits <= 64
), "Cannot represent min as a int64_t"
if self.is_signed():
return -(1 << (self.size_bits - 1))
else:
return 0
@functools.cached_property
def id(self) -> int:
"""
Convert the ScalarType to an int which can be passed to pytorch custom
ops. This layout of the int must be kept in sync with the C++
ScalarType's from_id method.
"""
val = 0
offset = 0
def or_and_advance(member, bit_width):
nonlocal val
nonlocal offset
bit_mask = (1 << bit_width) - 1
val = val | (int(member) & bit_mask) << offset
offset = offset + bit_width
or_and_advance(self.exponent, 8)
or_and_advance(self.mantissa, 8)
or_and_advance(self.signed, 1)
or_and_advance(self.bias, 32)
or_and_advance(self._finite_values_only, 1)
or_and_advance(self.nan_repr.value, 8)
assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
return val
@property
def size_bits(self) -> int:
return self.exponent + self.mantissa + int(self.signed)
def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_min() - self.bias
def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_max() - self.bias
def is_signed(self) -> bool:
"""
If the type is signed (i.e. has a sign bit), same as `signed`
added for consistency with:
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
"""
return self.signed
def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0
def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE.value
def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only
def __str__(self) -> str:
"""
naming generally follows: https://github.com/jax-ml/ml_dtypes
for floating point types (leading f) the scheme is:
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
flags:
- no-flags: means it follows IEEE 754 conventions
- f: means finite values only (no infinities)
- n: means nans are supported (non-standard encoding)
for integer types the scheme is:
`[u]int<size_bits>[b<bias>]`
- if bias is not present it means it is zero
"""
if self.is_floating_point():
ret = (
"float"
+ str(self.size_bits)
+ "_e"
+ str(self.exponent)
+ "m"
+ str(self.mantissa)
)
if not self.is_ieee_754():
if self._finite_values_only:
ret = ret + "f"
if self.nan_repr != NanRepr.NONE:
ret = ret + "n"
return ret
else:
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
if self.has_bias():
ret = ret + "b" + str(self.bias)
return ret
def __repr__(self) -> str:
return "ScalarType." + self.__str__()
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError
#
# Convenience Constructors
#
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
"Create a signed integer scalar type (size_bits includes sign-bit)."
ret = cls(0, size_bits - 1, True, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
"""Create a unsigned integer scalar type."""
ret = cls(0, size_bits, False, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
assert mantissa > 0 and exponent > 0
ret = cls(exponent, mantissa, True, 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_(
cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
) -> "ScalarType":
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
assert mantissa > 0 and exponent > 0
assert nan_repr != NanRepr.IEEE_754, (
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions"
)
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
ret.id # noqa B018: make sure the id is cached
return ret
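For readers skimming the removed implementation: the `id` property packs the dataclass fields into a single int64 (so a type can be handed to custom ops), using the field order and bit widths hard-coded in or_and_advance. A short check of that layout, assuming the ScalarType class above is in scope:

```python
t = ScalarType.uint(4, 8)  # the GPTQ "uint4b8" type

expected = (
    t.exponent
    | (t.mantissa << 8)
    | (int(t.signed) << 16)
    | (t.bias << 17)
    | (int(t._finite_values_only) << 49)
    | (t.nan_repr.value << 50)
)
assert t.id == expected
```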
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means it is zero
class scalar_types:
int4 = ScalarType.int_(4, None)
uint4 = ScalarType.uint(4, None)
int8 = ScalarType.int_(8, None)
uint8 = ScalarType.uint(8, None)
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)
# "gptq" types
uint2b2 = ScalarType.uint(2, 2)
uint3b4 = ScalarType.uint(3, 4)
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)
# colloquial names
bfloat16 = float16_e8m7
float16 = float16_e5m10
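A quick sanity check of the presets themselves, again assuming the definitions above are in scope: the GPTQ type uint4b8 stores raw values 0-15 with a bias of 8, and float8_e4m3fn is the finite-only, extended-range format whose largest magnitude works out to 448:

```python
t = scalar_types.uint4b8
print(str(t), t.min(), t.max())      # uint4b8 -8 7

f8 = scalar_types.float8_e4m3fn
print(str(f8), f8.max())             # float8_e4m3fn 448.0
print(f8.has_infs(), f8.has_nans())  # False True
```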
@@ -27,3 +27,5 @@ pip install cuda-python nvidia-cuda-nvrtc-cu12
 
 # For DeepSeek-VL2
 pip install timm
+
+pip install sgl-kernel==0.0.5.post3 --force-reinstall