Unverified commit 74e0ac1d, authored by Lianmin Zheng and committed by GitHub

Clean up `import vllm` in quantization/__init__.py (#4834)

parent ef9a378a
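The diff below replaces scattered `import vllm` calls and per-call-site guards with one pattern: import every vllm-backed symbol in a single `try` block, record the outcome in a `VLLM_AVAILABLE` flag, and bind placeholder objects when the import fails so the module-level names always exist. A minimal sketch of that pattern, with a simplified registry rather than the exact sglang code:

```python
from typing import Any, Dict, Type


class DummyConfig:
    """Placeholder bound to vllm-backed names when vllm is missing."""

    def override_quantization_method(self, *args, **kwargs):
        return None


try:
    # All vllm-dependent symbols are imported in one place.
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False
    AQLMConfig = DummyConfig  # the name still exists, so registry lookups still work

BASE_QUANTIZATION_METHODS: Dict[str, Type[Any]] = {}  # vllm-free methods would go here
VLLM_QUANTIZATION_METHODS: Dict[str, Type[Any]] = {"aqlm": AQLMConfig}

# The registry is always fully populated; availability is checked at lookup time.
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


def get_quantization_config(quantization: str) -> Type[Any]:
    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(f"{quantization} quantization requires vllm; please install vllm first.")
    return QUANTIZATION_METHODS[quantization]
```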
@@ -4,19 +4,15 @@ on:
   push:
     branches: [ main ]
     paths:
-      - "python/pyproject.toml"
-      - "python/sglang/**"
-      - "test/**"
-      - "docs/**"
+      - "python/**"
       - "scripts/**"
+      - "test/**"
   pull_request:
     branches: [ main ]
     paths:
-      - "python/pyproject.toml"
-      - "python/sglang/**"
-      - "test/**"
-      - "docs/**"
+      - "python/**"
       - "scripts/**"
+      - "test/**"
   workflow_dispatch:
     inputs:
       version:
...
@@ -4,19 +4,15 @@ on:
   push:
     branches: [ main ]
    paths:
-      - "python/pyproject.toml"
-      - "python/sglang/**"
-      - "test/**"
-      - "docs/**"
+      - "python/**"
       - "scripts/**"
+      - "test/**"
   pull_request:
     branches: [ main ]
     paths:
-      - "python/pyproject.toml"
-      - "python/sglang/**"
-      - "test/**"
-      - "docs/**"
+      - "python/**"
       - "scripts/**"
+      - "test/**"

 concurrency:
   group: vllm-dependency-test-${{ github.ref }}
...
@@ -17,6 +17,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
 [project.optional-dependencies]
 runtime_common = [
+    "compressed-tensors",
     "datasets",
     "decord",
     "fastapi",
@@ -56,7 +57,12 @@ srt = [
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20250114, not from public vllm whl
-srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
+srt_hip = [
+    "sglang[runtime_common]",
+    "torch",
+    "vllm==0.6.7.dev2",
+    "outlines==0.1.11"
+]
 # xpu is not enabled in public vllm and torch whl,
 # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html install vllm
...
@@ -22,11 +22,7 @@ import torch
 from transformers import PretrainedConfig

 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.layers.quantization import (
-    BASE_QUANTIZATION_METHODS,
-    QUANTIZATION_METHODS,
-    VLLM_AVAILABLE,
-)
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.utils import get_bool_env_var, is_hip

 logger = logging.getLogger(__name__)
@@ -239,12 +235,7 @@ class ModelConfig:
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
-        # Select supported quantization methods based on vllm availability
-        if VLLM_AVAILABLE:
-            supported_quantization = [*QUANTIZATION_METHODS]
-        else:
-            supported_quantization = [*BASE_QUANTIZATION_METHODS]
+        supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = [
             "awq",
             "gptq",
@@ -282,11 +273,7 @@ class ModelConfig:
         quant_method = quant_cfg.get("quant_method", "").lower()

         # Detect which checkpoint is it
-        # Only iterate through currently available quantization methods
-        available_methods = (
-            QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
-        )
-        for _, method in available_methods.items():
+        for _, method in QUANTIZATION_METHODS.items():
             quantization_override = method.override_quantization_method(
                 quant_cfg, self.quantization
             )
...
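Because the registry is now always fully populated, `_verify_quantization` can drop the availability branch and simply ask every registered config class whether it wants to override the detected method. Roughly, the detection loop behaves like the sketch below (the config class and its return value are hypothetical):

```python
from typing import Dict, Optional, Type


class FakeMarlinConfig:
    """Hypothetical stand-in for a registered quantization config class."""

    @classmethod
    def override_quantization_method(cls, quant_cfg: dict, user_quant: Optional[str]) -> Optional[str]:
        # A real config may upgrade a checkpoint, e.g. "gptq" -> "gptq_marlin".
        return "gptq_marlin" if quant_cfg.get("quant_method") == "gptq" else None


QUANTIZATION_METHODS: Dict[str, Type] = {"gptq_marlin": FakeMarlinConfig}


def detect_quantization(quant_cfg: dict, user_quant: Optional[str]) -> Optional[str]:
    # Every registered method gets a chance to claim the checkpoint.
    for _, method in QUANTIZATION_METHODS.items():
        override = method.override_quantization_method(quant_cfg, user_quant)
        if override:
            return override
    return user_quant


print(detect_quantization({"quant_method": "gptq"}, None))  # -> gptq_marlin
```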
@@ -17,12 +17,12 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F

+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()

-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
-
 expert_distribution_recorder = ExpertDistributionRecorder()
...
@@ -9,12 +9,24 @@ import torch
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinConfig,
+        AWQMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -22,24 +34,24 @@ try:
     from vllm.model_executor.layers.quantization.qqq import QQQConfig
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

-    from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False

     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
-        pass
+        def override_quantization_method(self, *args, **kwargs):
+            return None

-    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
-        DummyConfig
-    )
-    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
-        GPTQMarlin24Config
-    ) = DummyConfig
-    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DeepSpeedFPConfig
+    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
+        MarlinConfig
+    ) = QQQConfig = Int8TpuConfig = DummyConfig

+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -47,9 +59,14 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -61,10 +78,8 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
 }

-# Add vllm-dependent methods if available
-QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
-if VLLM_AVAILABLE:
-    VLLM_QUANTIZATION_METHODS = {
-        "aqlm": AQLMConfig,
-        "awq": AWQConfig,
-        "deepspeedfp": DeepSpeedFPConfig,
+# VLLM-dependent quantization methods
+VLLM_QUANTIZATION_METHODS = {
+    "aqlm": AQLMConfig,
+    "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
@@ -79,8 +94,9 @@ if VLLM_AVAILABLE:
     "experts_int8": ExpertsInt8Config,
     "gptq_marlin": GPTQMarlinConfig,
     "gptq": GPTQConfig,
-    }
-    QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
+}
+
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
@@ -89,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
+    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
+        raise ValueError(
+            f"{quantization} quantization requires some operators from vllm. "
+            "Please install vllm by `pip install vllm==0.7.2`"
+        )
+
     return QUANTIZATION_METHODS[quantization]
@@ -153,13 +175,6 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -186,18 +201,6 @@ def get_linear_quant_method(
 def gptq_get_quant_method(self, layer, prefix):
-    if not VLLM_AVAILABLE:
-        return None
-
-    try:
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
-        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
     if isinstance(layer, FusedMoE):
         return GPTQMarlinMoEMethod(self)
@@ -209,8 +212,6 @@ def gptq_get_quant_method(self, layer, prefix):
         return get_linear_quant_method(
             self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
         )
-    except ImportError:
-        pass
     return None
@@ -229,7 +230,6 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         builtins.isinstance = original_isinstance
         return

-    try:
     from vllm.model_executor.layers.fused_moe import FusedMoE
     from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -237,9 +237,7 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     )

     from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-    from sglang.srt.layers.moe.fused_moe_triton.layer import (
-        FusedMoE as PatchedFusedMoE,
-    )
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
     from sglang.srt.layers.vocab_parallel_embedding import (
         VocabParallelEmbedding as PatchedVocabParallelEmbedding,
     )
@@ -254,8 +252,6 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
             return original_isinstance(obj, classinfo)

         builtins.isinstance = patched_isinstance
-    except ImportError:
-        return


 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
@@ -263,10 +259,6 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-    if not VLLM_AVAILABLE:
-        return
-
-    try:
     original_apply = class_obj.apply
     sig = inspect.signature(original_apply)
     param_names = list(sig.parameters.keys())
@@ -312,25 +304,10 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
             return original_apply(**kwargs)

         setattr(class_obj, "apply", new_apply)
-    except (ImportError, AttributeError):
-        return


 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    if not VLLM_AVAILABLE:
-        return
-
-    try:
-        from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
-        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-            CompressedTensorsW8A8Fp8MoEMethod,
-            CompressedTensorsWNA16MoEMethod,
-        )
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinMoEMethod,
-        )
-
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
@@ -338,16 +315,8 @@ def monkey_patch_quant_configs():
     monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-    except ImportError:
-        return


 # Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
-
-__all__ = [
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
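The module still monkey-patches vllm's MoE quantization methods, but it now does so once at import time when vllm is present instead of re-checking availability inside each helper. The underlying `setattr`-based technique looks roughly like this generic sketch (not the sglang implementation):

```python
import inspect


class SomeMoEMethod:
    """Stand-in for a vllm method class whose `apply` signature we adapt to."""

    def apply(self, x, router_logits=None):
        return x


def monkey_patch_apply(cls) -> None:
    """Wrap cls.apply so keyword arguments it does not accept are dropped."""
    original_apply = cls.apply
    accepted = set(inspect.signature(original_apply).parameters)

    def new_apply(self, *args, **kwargs):
        # Keep only the keyword arguments the original method understands.
        kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        return original_apply(self, *args, **kwargs)

    setattr(cls, "apply", new_apply)


monkey_patch_apply(SomeMoEMethod)
print(SomeMoEMethod().apply(3, router_logits=None, unexpected_arg=1))  # -> 3
```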
 # SPDX-License-Identifier: Apache-2.0
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 import torch
 from sgl_kernel import awq_dequantize
...
@@ -24,6 +24,7 @@ import triton.language as tl
 from sglang.srt.utils import (
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
     get_device_sm,
@@ -43,7 +44,7 @@ if _is_cuda:
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

     sm_version = get_device_sm()
-    if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
+    if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
         _enable_jit_deepgemm = True
...
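Reading the flag through `get_bool_env_var` also makes the toggle tolerant of values such as `true`/`false`, which the previous `int(os.getenv(...))` form would reject with a `ValueError`. A hypothetical equivalent helper, only to illustrate the behavior being relied on (the real sglang helper may differ in detail):

```python
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Accept common truthy spellings; anything else counts as False.
    value = os.getenv(name, default)
    return value.strip().lower() in ("1", "true", "yes", "on")


os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
print(get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"))  # -> False

# The old form would crash on the same input:
# int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1"))  -> ValueError: invalid literal for int()
```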
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
 _is_cuda = is_cuda()

 try:
-    import vllm
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+        check_marlin_supported,
+    )
+    from vllm.scalar_type import scalar_types

     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False

+    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+
+    class scalar_types:
+        uint4b8 = "uint4b8"
+        uint8b128 = "uint8b128"
+
 logger = logging.getLogger(__name__)
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["GPTQLinearMethod"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-
+    ) -> Optional[GPTQLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.quantization import get_linear_quant_method

         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""

-    if VLLM_AVAILABLE:
-        from vllm.scalar_type import scalar_types
-
-        # (num_bits, is_sym) -> quant_type
-        TYPE_MAP = {
-            (4, True): scalar_types.uint4b8,
-            (8, True): scalar_types.uint8b128,
-        }
-    else:
-        raise ImportError("vllm is not installed")
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }

     def __init__(
         self,
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
             "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
         )

+        # (num_bits, is_sym) -> quant_type
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

     def __repr__(self) -> str:
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
+    ) -> Optional[QuantizeMethodBase]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.quantization import get_linear_quant_method
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
-        if not VLLM_AVAILABLE:
-            return False
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
         sym = quant_config.get("sym")
         desc_act = quant_config.get("desc_act")

-        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-            check_marlin_supported,
-        )
-
         if not _is_cuda:
             return False
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["MarlinLinearMethod"]:
-        if not VLLM_AVAILABLE:
-            raise ImportError("vllm is not installed")
-        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-
-        # Delay import to avoid circular dependency
+    ) -> Optional[MarlinLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

         if isinstance(layer, LinearBase) or (
...
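Because the `except ImportError` branch binds `GPTQLinearMethod`, `MarlinLinearMethod`, and `QuantizeMethodBase` to `Any` (and defines a stub `scalar_types`), the return annotations above can name those types directly instead of quoting them as strings; they still evaluate when vllm is absent. A tiny illustration of why that works, using a deliberately nonexistent module name:

```python
from typing import Any, Optional

try:
    from some_optional_backend import LinearMethod  # hypothetical optional dependency
except ImportError:
    # Bind the name to Any so annotations below still evaluate at class-definition time.
    LinearMethod = Any


class Config:
    def get_quant_method(self, layer) -> Optional[LinearMethod]:
        return None


print(Config().get_quant_method(layer=None))  # -> None
```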
@@ -53,8 +53,6 @@ class TpModelWorker:
         req_to_token_pool: Optional[ReqToTokenPool] = None,
         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
-        self.worker = self
-
         # Parse args
         self.tp_rank = tp_rank
@@ -134,6 +132,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

+        # A reference so this class has the same members as TpModelWorkerClient
+        self.worker = self
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
...
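Moving `self.worker = self` to the end of `__init__` keeps the alias that gives `TpModelWorker` the same `.worker` member as `TpModelWorkerClient`, but only assigns it once initialization has succeeded. In miniature (class names simplified):

```python
class Worker:
    def __init__(self):
        # ... real initialization would happen here ...
        # Alias assigned last, so callers can uniformly read `.worker`
        # whether they hold a Worker or a WorkerClient.
        self.worker = self


class WorkerClient:
    def __init__(self, worker: Worker):
        self.worker = worker


print(Worker().worker)                 # the Worker itself
print(WorkerClient(Worker()).worker)   # the wrapped Worker
```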
@@ -73,7 +73,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
...
 #!/bin/bash
+set -euxo pipefail

 # Install the dependency in CI.
-set -euxo pipefail

-# Use repo from environment variable, passed from GitHub Actions
+# Use repo from environment variables, passed from GitHub Actions
 FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
 SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
@@ -17,17 +15,12 @@ pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2
 rm -rf /root/.cache/flashinfer
 # Force reinstall flashinfer and torch_memory_saver
 pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
-pip install torch_memory_saver --force-reinstall
+pip install sgl-kernel==0.0.5.post3 --force-reinstall
+pip install torch_memory_saver

-pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets
+pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm

 # For compling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12

-# For DeepSeek-VL2
-pip install timm
-
-pip install sgl-kernel==0.0.5.post3 --force-reinstall
-
 pip uninstall vllm -y || true
@@ -45,7 +45,7 @@ class TestEAGLEEngine(CustomTestCase):
         "mem_fraction_static": 0.7,
         "cuda_graph_max_bs": 4,
     }
-    NUM_CONFIGS = 3
+    NUM_CONFIGS = 2

     def setUp(self):
         self.prompt = "Today is a sunny day and I like"
@@ -61,8 +61,6 @@ class TestEAGLEEngine(CustomTestCase):
         configs = [
             # Basic config
             self.BASE_CONFIG,
-            # Disable cuda graph
-            {**self.BASE_CONFIG, "disable_cuda_graph": True},
             # Chunked prefill
             {**self.BASE_CONFIG, "chunked_prefill_size": 4},
         ]
...
@@ -28,7 +28,7 @@ class TestTritonAttnBackend(CustomTestCase):
             "triton",
             "--enable-torch-compile",
             "--cuda-graph-max-bs",
-            16,
+            4,
         ],
     )
...