Unverified Commit 04e3ff69 authored by Xiaoyu Zhang, committed by GitHub

Support compressed tensors fp8w8a8 (#4743)

parent 45fdf1f7
name: VLLM Dependency Test
on:
push:
branches: [ main ]
paths:
- "python/pyproject.toml"
- "python/sglang/**"
- "test/**"
- "docs/**"
- "scripts/**"
pull_request:
branches: [ main ]
paths:
- "python/pyproject.toml"
- "python/sglang/**"
- "test/**"
- "docs/**"
- "scripts/**"
concurrency:
group: vllm-dependency-test-${{ github.ref }}
cancel-in-progress: true
jobs:
vllm-dependency-test:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft == false
runs-on: 1-gpu-runner
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
env:
FLASHINFER_REPO: 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python'
run: |
bash scripts/ci_install_dependency.sh
pip install "vllm>=0.6.4.post1,<=0.7.2"
- name: Run VLLM dependency tests
timeout-minutes: 60
run: |
cd test/srt
python3 run_suite.py --suite vllm_dependency_test --timeout-per-file 3600
...@@ -47,7 +47,6 @@ srt = [ ...@@ -47,7 +47,6 @@ srt = [
"sgl-kernel==0.0.5.post3", "sgl-kernel==0.0.5.post3",
"flashinfer_python==0.2.3", "flashinfer_python==0.2.3",
"torch==2.5.1", "torch==2.5.1",
"vllm>=0.6.4.post1,<=0.7.2",
"cuda-python", "cuda-python",
"outlines>=0.0.44,<=0.1.11", "outlines>=0.0.44,<=0.1.11",
] ]
......
...@@ -22,7 +22,11 @@ import torch ...@@ -22,7 +22,11 @@ import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.layers.quantization import QUANTIZATION_METHODS from sglang.srt.layers.quantization import (
BASE_QUANTIZATION_METHODS,
QUANTIZATION_METHODS,
VLLM_AVAILABLE,
)
from sglang.srt.utils import get_bool_env_var, is_hip from sglang.srt.utils import get_bool_env_var, is_hip
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -235,7 +239,12 @@ class ModelConfig: ...@@ -235,7 +239,12 @@ class ModelConfig:
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] # Select supported quantization methods based on vllm availability
if VLLM_AVAILABLE:
supported_quantization = [*QUANTIZATION_METHODS]
else:
supported_quantization = [*BASE_QUANTIZATION_METHODS]
rocm_supported_quantization = [ rocm_supported_quantization = [
"awq", "awq",
"gptq", "gptq",
...@@ -273,7 +282,11 @@ class ModelConfig: ...@@ -273,7 +282,11 @@ class ModelConfig:
quant_method = quant_cfg.get("quant_method", "").lower() quant_method = quant_cfg.get("quant_method", "").lower()
# Detect which checkpoint format it is # Detect which checkpoint format it is
for _, method in QUANTIZATION_METHODS.items(): # Only iterate through currently available quantization methods
available_methods = (
QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
)
for _, method in available_methods.items():
quantization_override = method.override_quantization_method( quantization_override = method.override_quantization_method(
quant_cfg, self.quantization quant_cfg, self.quantization
) )
......
...@@ -1316,7 +1316,10 @@ vllm_get_world_group = None ...@@ -1316,7 +1316,10 @@ vllm_get_world_group = None
def monkey_patch_vllm_parallel_state(reverse: bool = False): def monkey_patch_vllm_parallel_state(reverse: bool = False):
import vllm.distributed.parallel_state as vllm_parrlel_state try:
import vllm.distributed.parallel_state as vllm_parrlel_state
except ImportError:
return
global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group
if vllm_get_pp_group is None: if vllm_get_pp_group is None:
......
...@@ -23,6 +23,7 @@ from sglang.srt.layers.parameter import ( ...@@ -23,6 +23,7 @@ from sglang.srt.layers.parameter import (
PackedvLLMParameter, PackedvLLMParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
RowvLLMParameter, RowvLLMParameter,
_ColumnvLLMParameter,
) )
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
...@@ -423,8 +424,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -423,8 +424,6 @@ class ColumnParallelLinear(LinearBase):
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
from sglang.srt.layers.parameter import _ColumnvLLMParameter
if isinstance(param, _ColumnvLLMParameter): if isinstance(param, _ColumnvLLMParameter):
param.load_column_parallel_weight( param.load_column_parallel_weight(
loaded_weight, loaded_weight,
...@@ -1247,7 +1246,7 @@ class RowParallelLinear(LinearBase): ...@@ -1247,7 +1246,7 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if isinstance(param, BasevLLMParameter): if isinstance(param, RowvLLMParameter):
# This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py, # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
# It supports additional parameters like tp_rank and use_presharded_weights. # It supports additional parameters like tp_rank and use_presharded_weights.
param.load_row_parallel_weight( param.load_row_parallel_weight(
......
...@@ -8,7 +8,6 @@ from typing import Callable, Optional ...@@ -8,7 +8,6 @@ from typing import Callable, Optional
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import select_experts
...@@ -69,6 +68,8 @@ def moe_forward_native( ...@@ -69,6 +68,8 @@ def moe_forward_native(
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -305,6 +305,7 @@ class FusedMoE(torch.nn.Module): ...@@ -305,6 +305,7 @@ class FusedMoE(torch.nn.Module):
self.use_presharded_weights = use_presharded_weights self.use_presharded_weights = use_presharded_weights
self.inplace = inplace self.inplace = inplace
self.no_combine = no_combine self.no_combine = no_combine
self.local_num_experts = num_experts
if quant_config is None: if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = ( self.quant_method: Optional[QuantizeMethodBase] = (
...@@ -629,8 +630,6 @@ class FusedMoE(torch.nn.Module): ...@@ -629,8 +630,6 @@ class FusedMoE(torch.nn.Module):
custom_routing_function=self.custom_routing_function, custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias, correction_bias=self.correction_bias,
activation=self.activation, activation=self.activation,
inplace=self.inplace,
no_combine=self.no_combine,
) )
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
......
...@@ -17,11 +17,12 @@ from typing import Callable, Optional ...@@ -17,11 +17,12 @@ from typing import Callable, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from sglang.srt.utils import get_compiler_backend, is_cuda from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_hip = is_hip()
from sglang.srt.managers.utils import ExpertDistributionRecorder from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
expert_distribution_recorder = ExpertDistributionRecorder() expert_distribution_recorder = ExpertDistributionRecorder()
...@@ -53,10 +54,10 @@ def fused_topk( ...@@ -53,10 +54,10 @@ def fused_topk(
topk: int, topk: int,
renormalize: bool, renormalize: bool,
): ):
if _is_cuda: if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax from sgl_kernel import topk_softmax
else: else:
from vllm import _custom_ops as ops from vllm import _custom_ops as vllm_ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
...@@ -70,7 +71,7 @@ def fused_topk( ...@@ -70,7 +71,7 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device M, topk, dtype=torch.int32, device=hidden_states.device
) )
if _is_cuda: if _is_cuda or _is_hip:
topk_softmax( topk_softmax(
topk_weights, topk_weights,
topk_ids, topk_ids,
...@@ -78,7 +79,7 @@ def fused_topk( ...@@ -78,7 +79,7 @@ def fused_topk(
gating_output.float(), gating_output.float(),
) )
else: else:
ops.topk_softmax( vllm_ops.topk_softmax(
topk_weights, topk_weights,
topk_ids, topk_ids,
token_expert_indicies, token_expert_indicies,
......
...@@ -12,9 +12,6 @@ try: ...@@ -12,9 +12,6 @@ try:
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
...@@ -26,6 +23,8 @@ try: ...@@ -26,6 +23,8 @@ try:
from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
VLLM_AVAILABLE = True VLLM_AVAILABLE = True
except ImportError: except ImportError:
VLLM_AVAILABLE = False VLLM_AVAILABLE = False
...@@ -44,8 +43,10 @@ except ImportError: ...@@ -44,8 +43,10 @@ except ImportError:
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
...@@ -55,10 +56,9 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -55,10 +56,9 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"fp8": Fp8Config, "fp8": Fp8Config,
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"modelopt": ModelOptFp8Config, "modelopt": ModelOptFp8Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig,
"w8a8_int8": W8A8Int8Config, "w8a8_int8": W8A8Int8Config,
"w8a8_fp8": W8A8Fp8Config, "w8a8_fp8": W8A8Fp8Config,
"compressed-tensors": CompressedTensorsConfig,
} }
# Add vllm-dependent methods if available # Add vllm-dependent methods if available
...@@ -74,10 +74,11 @@ if VLLM_AVAILABLE: ...@@ -74,10 +74,11 @@ if VLLM_AVAILABLE:
"gguf": GGUFConfig, "gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin_24": GPTQMarlin24Config,
"awq_marlin": AWQMarlinConfig, "awq_marlin": AWQMarlinConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig, "bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig, "qqq": QQQConfig,
"experts_int8": ExpertsInt8Config, "experts_int8": ExpertsInt8Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig,
} }
QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS) QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
......
...@@ -38,6 +38,11 @@ class QuantizeMethodBase(ABC): ...@@ -38,6 +38,11 @@ class QuantizeMethodBase(ABC):
class QuantizationConfig(ABC): class QuantizationConfig(ABC):
"""Base class for quantization configs.""" """Base class for quantization configs."""
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()
@abstractmethod @abstractmethod
def get_name(self) -> str: def get_name(self) -> str:
"""Name of the quantization method.""" """Name of the quantization method."""
......
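As a hedged illustration of how this mapping is used (the module names below are hypothetical examples, not taken from this commit), a model with fused projections would populate it roughly like:
# e.g. quant_config.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
# e.g. quant_config.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
This fused-to-component mapping is what should_ignore_layer and _match_fused_layer further below consume as fused_mapping.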
# quantization compressed_tensors module
To support quantized models in the compressed_tensors format, we adapted https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors into SGLang.
Currently, only the `w8a8_fp8` compressed_tensors scheme is supported. If you need other formats, please open an issue via this [link](https://github.com/sgl-project/sglang/issues).
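As a hedged sketch (not part of this commit's code), the effect of the new registry entry can be seen by resolving a checkpoint's `quant_method` string to its config class:

# Minimal sketch, assuming the registry added in sglang/srt/layers/quantization/__init__.py.
from sglang.srt.layers.quantization import QUANTIZATION_METHODS

quant_cfg = {"quant_method": "compressed-tensors"}  # as read from a checkpoint's quantization config
config_cls = QUANTIZATION_METHODS[quant_cfg["quant_method"]]
print(config_cls.__name__)  # CompressedTensorsConfig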
# SPDX-License-Identifier: Apache-2.0
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsW8A8Fp8",
]
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Optional
import torch
__all__ = ["CompressedTensorsScheme"]
class CompressedTensorsScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by CompressedTensors.
"""
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
can be different for each scheme.
"""
raise NotImplementedError
@abstractmethod
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError
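For orientation, a hedged summary of the order in which a quantization method is expected to drive these hooks (a restatement of the abstract interface above, not new behavior):
# 1. create_weights(...)                  -> register weight/scale parameters on the layer
# 2. checkpoint loading fills those parameters via their weight_loader callbacks
# 3. process_weights_after_loading(layer) -> requantize/transpose and finalize Parameters
# 4. apply_weights(layer, x, bias)        -> run the quantized GEMM during the forward pass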
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter
from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.fp8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
__all__ = ["CompressedTensorsW8A8Fp8"]
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if self.strategy == QuantizationStrategy.TENSOR:
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
if is_fp8_fnuz():
input_scale = getattr(layer, "input_scale", None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=max_w_scale, input_scale=input_scale
)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
if is_fp8_fnuz():
input_scale = getattr(layer, "input_scale", None)
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=input_scale,
)
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)
else:
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
# INPUT SCALE
if self.is_static_input_scheme and hasattr(layer, "input_scale"):
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
else:
layer.input_scale = None
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
# min requirement for fp8 kernels
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import re
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional
from compressed_tensors import CompressionFormat
from torch.nn import Module
def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
CompressionFormat.naive_quantized.value,
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
]
return format in _ACTIVATION_QUANTIZATION_FORMATS
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> bool:
if layer_name is None:
return False
# layer_name = model.layers.0.self_attn.qkv_proj
# proj_name = qkv_proj
proj_name = layer_name.split(".")[-1]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping and layer_name not in ignore:
shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
# Layer should be ignored if shards are ignored.
should_ignore_layer = None
for shard_name in shard_names:
should_ignore_shard = check_equal_or_regex_match(
layer_name=shard_name, targets=ignore
)
# If shard_idx=0, set layer ignore to match shard.
if should_ignore_layer is None:
should_ignore_layer = should_ignore_shard
# If shard_idx=1+ confirm scheme matches prior shards.
elif should_ignore_shard != should_ignore_layer:
raise ValueError(
f"Found different quantization schemes for "
f"{shard_proj_names} in {layer_name}. vLLM "
"requires all to use the same scheme."
)
# Unfused layers like down_proj and o_proj will match
# the safetensors checkpoint already.
else:
should_ignore_layer = check_equal_or_regex_match(
layer_name=layer_name, targets=ignore
)
assert should_ignore_layer is not None
return should_ignore_layer
def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
"""
Checks whether layer_name exactly matches any target in the list, or
matches it as a regex when the target starts with 're:'.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def find_matched_target(
layer_name: Optional[str],
module: Module,
targets: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config a layer corresponds to.
Recall that a compressed-tensors config has a concept of
config_groups, where each layer can be quantized with a different
scheme.
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
First, we try to match the layer_name with a target.
Second, we try to match the module's name with a target.
Third, we try to map the layer_name to a list of fused module names.
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target.
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
"""
if layer_name is None:
layer_name = ""
matched_target = (
_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets, True)
or _match_fused_layer(layer_name, targets, fused_mapping)
)
if matched_target is None:
raise ValueError(
f"Unable to find matching target for {layer_name} in the "
"compressed-tensors config."
)
return matched_target
def _find_first_match(
value: str, targets: Iterable[str], check_contains: bool = False
) -> Optional[str]:
"""
Returns the first element of targets that matches value either
exactly or as a regex after 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
:param value: string to compare the list of targets against
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
for target in targets:
if _is_equal_or_regex_match(value, target, check_contains=check_contains):
return target
return None
def _is_equal_or_regex_match(
value: str, target: str, check_contains: bool = False
) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return True
elif check_contains:
if target.lower() in value.lower():
return True
elif target == value:
return True
return False
def _match_fused_layer(
layer_name: str,
target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]],
) -> Optional[str]:
"""
Match a fused layer name to its corresponding individual layer in
target_layers. Returns the first matched component target.
Implements an "all" matching strategy: a fused layer matches iff
all of its components match.
:param layer_name: layer name
:param target_layers: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
Examples:
layer_name = "model.layers.0.self_attn.qkv_proj"
target_layers = ["model.layers.0.self_attn.q_proj",
"model.layers.0.self_attn.k_proj",
"model.layers.0.self_attn.v_proj"]
"""
# find layer_name in mapping
fused = next((key for key in fused_mapping if layer_name.endswith(key)), None)
if fused is None:
return None
# expand path of unfused components
unfused_paths = [
layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
]
# for each unfused component, find a match in targets
unfused_matches: List[Optional[str]] = []
for unfused in unfused_paths:
for target in target_layers:
if _is_equal_or_regex_match(unfused, target):
unfused_matches.append(target)
break
else:
unfused_matches.append(None)
return unfused_matches[0] if all(unfused_matches) else None
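A hedged usage sketch of the matching helpers above (the layer names and patterns are illustrative only):

# Illustrative: exact match vs. "re:"-prefixed regex match.
targets = ["lm_head", r"re:.*\.mlp\.down_proj$"]
print(check_equal_or_regex_match("lm_head", targets))                          # True (exact match)
print(check_equal_or_regex_match("model.layers.0.mlp.down_proj", targets))     # True (regex match)
print(should_ignore_layer("model.layers.0.self_attn.o_proj", ignore=targets))  # False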
import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
...@@ -18,6 +19,7 @@ from sglang.srt.utils import ( ...@@ -18,6 +19,7 @@ from sglang.srt.utils import (
try: try:
import vllm import vllm
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True VLLM_AVAILABLE = True
except ImportError: except ImportError:
...@@ -31,19 +33,29 @@ if _is_hip and get_bool_env_var("CK_MOE"): ...@@ -31,19 +33,29 @@ if _is_hip and get_bool_env_var("CK_MOE"):
_is_cuda = is_cuda() _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
if use_vllm_cutlass_w8a8_fp8_kernel and VLLM_AVAILABLE:
from vllm import _custom_ops as ops
else:
from sgl_kernel import fp8_scaled_mm
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
_TORCH_VERSION = torch.__version__.split("+")[0]
try:
_TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
except ValueError:
_TORCH_VERSION_TUPLE = (0, 0, 0)
# Whether this platform supports the torch._scaled_mm rowwise feature.
# The condition is computed once because the checks are time-consuming.
USE_ROWWISE_TORCH_SCALED_MM = (
_is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
)
def cutlass_fp8_supported(): def cutlass_fp8_supported():
if not _is_cuda: if not _is_cuda:
...@@ -330,3 +342,223 @@ def apply_fp8_linear( ...@@ -330,3 +342,223 @@ def apply_fp8_linear(
if bias is not None: if bias is not None:
output = output + bias output = output + bias
return output.to(dtype=input.dtype).view(*output_shape) return output.to(dtype=input.dtype).view(*output_shape)
def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
"""
This class executes an FP8 linear layer using cutlass if supported and
torch._scaled_mm otherwise.
It needs to be a class instead of a method so that config can be read
in the __init__ method, as reading config is not allowed inside forward.
"""
def __init__(
self,
cutlass_fp8_supported: bool = cutlass_fp8_supported(),
use_per_token_if_dynamic: bool = False,
pad_output: Optional[bool] = None,
):
self.cutlass_fp8_supported = cutlass_fp8_supported
self.use_per_token_if_dynamic = use_per_token_if_dynamic
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
if pad_output is None:
enable_torch_compile = os.environ.get(
"SGLANG_ENABLE_TORCH_COMPILE", "0"
).lower() in ("1", "true", "yes")
pad_output = not enable_torch_compile
self.output_padding = 17 if pad_output else None
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
# TODO(luka) remove this parameter in favor of __init__
use_per_token_if_dynamic: Optional[bool] = None,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
# TODO(luka) this is here because currently MLA only decides this
# during the forward method instead of in __init__.
if use_per_token_if_dynamic is None:
use_per_token_if_dynamic = self.use_per_token_if_dynamic
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
# the sgl-kernel fp8_scaled_mm currently supports only per-channel W
if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant(
input_2d,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
# Fused GEMM_DQ
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
else:
assert (
weight_scale.numel() == weight.shape[1]
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
output = fp8_scaled_mm(
qinput,
weight,
x_scale,
weight_scale,
out_dtype=input.dtype,
bias=bias,
)
return output.view(*output_shape)
# torch._scaled_mm supports per-tensor weights + activations only,
# so fall back to the naive path for per-channel or per-token scales
else:
# Maybe apply padding to output, see comment in __init__
if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant(
input_2d,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
if self.output_padding:
pad_size = max(self.output_padding - qinput.shape[0], 0)
if pad_size > 0:
qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
else:
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = x_scale.numel() == 1
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output = torch._scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
elif (
use_per_token_if_dynamic
and not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM
):
# For now, this path is validated only on the ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm was introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, which only exists in torch 2.7 and above.
# For the CUDA platform, please validate whether
# torch._scaled_mm supports rowwise scaled GEMM.
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias,
)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = output.view(*output_shape)
return output
else:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
output = torch._scaled_mm(
qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
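The dequantization identity used by this fallback can be sanity-checked on CPU without any fp8 kernels; a minimal, illustrative check with plain fp32 stand-ins for the quantized operands:

# Illustrative check of C = s_x * s_w * (Xq @ Wq) == X @ W (the identity in the comment above), CPU-only.
import torch

X, W = torch.randn(4, 8), torch.randn(8, 16)
s_x = torch.tensor(0.05)            # per-tensor activation scale
s_w = torch.rand(16) * 0.1 + 0.01   # per-channel weight scale
Xq, Wq = X / s_x, W / s_w           # "quantized" operands, kept in fp32 for the check
C = (Xq @ Wq) * s_x * s_w           # apply the scales after the GEMM
assert torch.allclose(C, X @ W, atol=1e-4)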
...@@ -15,6 +15,11 @@ else: ...@@ -15,6 +15,11 @@ else:
from vllm import _custom_ops as vllm_ops from vllm import _custom_ops as vllm_ops
def is_fp8_fnuz() -> bool:
# Only device 0 is checked; this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
def is_layer_skipped( def is_layer_skipped(
prefix: str, prefix: str,
ignored_layers: List[str], ignored_layers: List[str],
...@@ -120,3 +125,29 @@ def requantize_with_max_scale( ...@@ -120,3 +125,29 @@ def requantize_with_max_scale(
start = end start = end
return max_w_scale, weight return max_w_scale, weight
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_parameter(
mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]
) -> None:
old = getattr(mod, name)
if (
type(old) is type(new)
and old.dtype == new.dtype
and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
):
# If possible, update in place to avoid re-registering;
# this can be faster when the underlying storage is the same
update_tensor_inplace(old, new)
else:
# Fallback: re-register the parameter, converting to Parameter if necessary.
# This not only ensures we don't register a plain tensor as a parameter, but
# also ensures that all parameter subclasses get re-registered as
# parameters for `torch.compile` compatibility
if not isinstance(new, torch.nn.Parameter):
new = torch.nn.Parameter(new, requires_grad=False)
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))