Unverified Commit 74e0ac1d authored by Lianmin Zheng, committed by GitHub

Clean up `import vllm` in quantization/__init__.py (#4834)

parent ef9a378a
......@@ -4,19 +4,15 @@ on:
push:
branches: [ main ]
paths:
- "python/pyproject.toml"
- "python/sglang/**"
- "test/**"
- "docs/**"
- "python/**"
- "scripts/**"
- "test/**"
pull_request:
branches: [ main ]
paths:
- "python/pyproject.toml"
- "python/sglang/**"
- "test/**"
- "docs/**"
- "python/**"
- "scripts/**"
- "test/**"
workflow_dispatch:
inputs:
version:
......
......@@ -4,19 +4,15 @@ on:
push:
branches: [ main ]
paths:
- "python/pyproject.toml"
- "python/sglang/**"
- "test/**"
- "docs/**"
- "python/**"
- "scripts/**"
- "test/**"
pull_request:
branches: [ main ]
paths:
- "python/pyproject.toml"
- "python/sglang/**"
- "test/**"
- "docs/**"
- "python/**"
- "scripts/**"
- "test/**"
concurrency:
group: vllm-dependency-test-${{ github.ref }}
......
......@@ -17,6 +17,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
[project.optional-dependencies]
runtime_common = [
"compressed-tensors",
"datasets",
"decord",
"fastapi",
......@@ -56,7 +57,12 @@ srt = [
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
srt_hip = [
"sglang[runtime_common]",
"torch",
"vllm==0.6.7.dev2",
"outlines==0.1.11"
]
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
......
......@@ -22,11 +22,7 @@ import torch
from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.layers.quantization import (
BASE_QUANTIZATION_METHODS,
QUANTIZATION_METHODS,
VLLM_AVAILABLE,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.utils import get_bool_env_var, is_hip
logger = logging.getLogger(__name__)
......@@ -239,12 +235,7 @@ class ModelConfig:
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None:
# Select supported quantization methods based on vllm availability
if VLLM_AVAILABLE:
supported_quantization = [*QUANTIZATION_METHODS]
else:
supported_quantization = [*BASE_QUANTIZATION_METHODS]
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = [
"awq",
"gptq",
......@@ -282,11 +273,7 @@ class ModelConfig:
quant_method = quant_cfg.get("quant_method", "").lower()
# Detect which checkpoint format it is
# Only iterate through currently available quantization methods
available_methods = (
QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
)
for _, method in available_methods.items():
for _, method in QUANTIZATION_METHODS.items():
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization
)
......
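For context, the `_verify_quantization` hunk above now always iterates the single `QUANTIZATION_METHODS` registry and lets each config class claim a checkpoint through `override_quantization_method`. Below is a minimal, self-contained sketch of that dispatch pattern; the toy classes and the `detect_quantization` helper are illustrative stand-ins, not the real sglang/vllm configs.

from typing import Dict, Optional, Type


class ToyQuantConfig:
    """Stand-in for a QuantizationConfig subclass."""

    @classmethod
    def override_quantization_method(
        cls, quant_cfg: Dict, user_quant: Optional[str]
    ) -> Optional[str]:
        # Return a method name to take over this checkpoint, or None to decline.
        return None


class ToyMarlinConfig(ToyQuantConfig):
    @classmethod
    def override_quantization_method(cls, quant_cfg, user_quant):
        # e.g. upgrade a plain "gptq" checkpoint to a marlin-backed kernel
        if quant_cfg.get("quant_method") == "gptq" and quant_cfg.get("sym"):
            return "gptq_marlin"
        return None


TOY_METHODS: Dict[str, Type[ToyQuantConfig]] = {
    "gptq": ToyQuantConfig,
    "gptq_marlin": ToyMarlinConfig,
}


def detect_quantization(quant_cfg: Dict, user_quant: Optional[str]) -> Optional[str]:
    for _, method in TOY_METHODS.items():
        override = method.override_quantization_method(quant_cfg, user_quant)
        if override is not None:
            return override
    return user_quant


print(detect_quantization({"quant_method": "gptq", "sym": True}, "gptq"))  # gptq_marlin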
......@@ -17,12 +17,12 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
expert_distribution_recorder = ExpertDistributionRecorder()
......
......@@ -9,12 +9,24 @@ import torch
try:
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig,
AWQMoEMethod,
)
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsW8A8Fp8MoEMethod,
CompressedTensorsWNA16MoEMethod,
)
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config,
)
......@@ -22,24 +34,24 @@ try:
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
# Define empty classes as placeholders when vllm is not available
class DummyConfig:
pass
def override_quantization_method(self, *args, **kwargs):
return None
AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
DummyConfig
)
DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
GPTQMarlin24Config
) = DummyConfig
MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
DeepSpeedFPConfig
) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
MarlinConfig
) = QQQConfig = Int8TpuConfig = DummyConfig
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
......@@ -47,9 +59,14 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
CompressedTensorsConfig,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
UnquantizedEmbeddingMethod,
)
# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
......@@ -61,26 +78,25 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors": CompressedTensorsConfig,
}
# Add vllm-dependent methods if available
QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
if VLLM_AVAILABLE:
VLLM_QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fbgemm_fp8": FBGEMMFp8Config,
"marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"awq_marlin": AWQMarlinConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig,
}
QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fbgemm_fp8": FBGEMMFp8Config,
"marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"awq_marlin": AWQMarlinConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig,
}
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
......@@ -89,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
f"Invalid quantization method: {quantization}. "
f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
)
if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
raise ValueError(
f"{quantization} quantization requires some operators from vllm. "
"Pleaes install vllm by `pip install vllm==0.7.2`"
)
return QUANTIZATION_METHODS[quantization]
......@@ -153,13 +175,6 @@ def get_linear_quant_method(
prefix: str,
linear_method_cls: type,
):
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
UnquantizedEmbeddingMethod,
)
cloned_config = deepcopy(config)
parallel_lm_head_quantized = (
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
......@@ -186,31 +201,17 @@ def get_linear_quant_method(
def gptq_get_quant_method(self, layer, prefix):
if not VLLM_AVAILABLE:
return None
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
try:
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
if isinstance(self, GPTQConfig):
return get_linear_quant_method(
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
)
elif isinstance(self, GPTQMarlinConfig):
return get_linear_quant_method(
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
if isinstance(self, GPTQConfig):
return get_linear_quant_method(
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
)
elif isinstance(self, GPTQMarlinConfig):
return get_linear_quant_method(
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
)
except ImportError:
pass
return None
......@@ -229,33 +230,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
builtins.isinstance = original_isinstance
return
try:
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FusedMoE as PatchedFusedMoE,
)
from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
)
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
)
def patched_isinstance(obj, classinfo):
if classinfo is LinearBase:
return original_isinstance(obj, PatchedLinearBase)
if classinfo is FusedMoE:
return original_isinstance(obj, PatchedFusedMoE)
if classinfo is VocabParallelEmbedding:
return original_isinstance(obj, PatchedVocabParallelEmbedding)
return original_isinstance(obj, classinfo)
builtins.isinstance = patched_isinstance
except ImportError:
return
def patched_isinstance(obj, classinfo):
if classinfo is LinearBase:
return original_isinstance(obj, PatchedLinearBase)
if classinfo is FusedMoE:
return original_isinstance(obj, PatchedFusedMoE)
if classinfo is VocabParallelEmbedding:
return original_isinstance(obj, PatchedVocabParallelEmbedding)
return original_isinstance(obj, classinfo)
builtins.isinstance = patched_isinstance
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
......@@ -263,91 +259,64 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
Monkey patch the apply function of vllm's FusedMoEMethodBase.
Convert sglang arguments to vllm arguments.
"""
if not VLLM_AVAILABLE:
return
try:
original_apply = class_obj.apply
sig = inspect.signature(original_apply)
param_names = list(sig.parameters.keys())
has_correction_bias = "e_score_correction_bias" in param_names
def new_apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
):
assert activation == "silu"
assert inplace and not no_combine
kwargs = {
"self": self,
"layer": layer,
"x": x,
"router_logits": router_logits,
"top_k": top_k,
"renormalize": renormalize,
"use_grouped_topk": use_grouped_topk,
"topk_group": topk_group,
"num_expert_group": num_expert_group,
"custom_routing_function": custom_routing_function,
}
if correction_bias is not None:
if not has_correction_bias:
raise ValueError(
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
)
kwargs["e_score_correction_bias"] = correction_bias
return original_apply(**kwargs)
setattr(class_obj, "apply", new_apply)
except (ImportError, AttributeError):
return
original_apply = class_obj.apply
sig = inspect.signature(original_apply)
param_names = list(sig.parameters.keys())
has_correction_bias = "e_score_correction_bias" in param_names
def new_apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
):
assert activation == "silu"
assert inplace and not no_combine
kwargs = {
"self": self,
"layer": layer,
"x": x,
"router_logits": router_logits,
"top_k": top_k,
"renormalize": renormalize,
"use_grouped_topk": use_grouped_topk,
"topk_group": topk_group,
"num_expert_group": num_expert_group,
"custom_routing_function": custom_routing_function,
}
if correction_bias is not None:
if not has_correction_bias:
raise ValueError(
"Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
)
kwargs["e_score_correction_bias"] = correction_bias
return original_apply(**kwargs)
setattr(class_obj, "apply", new_apply)
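`monkey_patch_moe_apply` above wraps vllm's `apply` and uses `inspect.signature` to decide whether the installed version can accept `e_score_correction_bias` before forwarding it. A self-contained sketch of that wrap-and-adapt pattern with toy classes (not the real vllm `FusedMoEMethodBase`):

import inspect
from typing import Optional


class NewMethod:
    def apply(self, x: int, e_score_correction_bias: Optional[int] = None) -> int:
        return x * 2 + (e_score_correction_bias or 0)


def patch_apply(cls) -> None:
    original_apply = cls.apply
    # Inspect the wrapped signature once, at patch time.
    accepts_bias = "e_score_correction_bias" in inspect.signature(original_apply).parameters

    def new_apply(self, x: int, correction_bias: Optional[int] = None) -> int:
        kwargs = {"x": x}
        if correction_bias is not None:
            if not accepts_bias:
                raise ValueError("installed version is too old for correction_bias")
            kwargs["e_score_correction_bias"] = correction_bias
        return original_apply(self, **kwargs)

    cls.apply = new_apply


patch_apply(NewMethod)
print(NewMethod().apply(3, correction_bias=1))  # bias is forwarded -> 7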
def monkey_patch_quant_configs():
"""Apply all monkey patches in one place."""
if not VLLM_AVAILABLE:
return
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
try:
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsW8A8Fp8MoEMethod,
CompressedTensorsWNA16MoEMethod,
)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinMoEMethod,
)
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
monkey_patch_moe_apply(AWQMoEMethod)
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
except ImportError:
return
monkey_patch_moe_apply(AWQMoEMethod)
monkey_patch_moe_apply(GPTQMarlinMoEMethod)
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
monkey_patch_quant_configs()
__all__ = [
"get_quantization_config",
"QUANTIZATION_METHODS",
]
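The refactor above boils down to one module-level pattern: import the vllm-backed configs once inside a `try`/`except ImportError`, record the result in `VLLM_AVAILABLE`, substitute placeholders on failure, and raise a descriptive error only when a vllm-only method is actually requested. A compact, self-contained sketch of that pattern follows; the registry contents and names are illustrative, not the real module.

from typing import Any, Dict, Type

try:
    # Optional heavy dependency; may not be installed at all.
    from vllm.model_executor.layers.quantization.gptq import GPTQConfig as VllmGPTQConfig

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    class VllmGPTQConfig:  # placeholder so the registries below can still be built
        @classmethod
        def override_quantization_method(cls, *args: Any, **kwargs: Any) -> None:
            return None


BASE_METHODS: Dict[str, Type] = {"fp8": object}           # never needs vllm
VLLM_METHODS: Dict[str, Type] = {"gptq": VllmGPTQConfig}   # needs vllm kernels
METHODS: Dict[str, Type] = {**BASE_METHODS, **VLLM_METHODS}


def get_method(name: str) -> Type:
    if name not in METHODS:
        raise ValueError(f"Invalid quantization method: {name}")
    if name in VLLM_METHODS and not VLLM_AVAILABLE:
        raise ValueError(f"{name} requires vllm; install it with `pip install vllm`")
    return METHODS[name]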
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
import torch
from sgl_kernel import awq_dequantize
......
......@@ -24,6 +24,7 @@ import triton.language as tl
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
get_device_core_count,
get_device_name,
get_device_sm,
......@@ -43,7 +44,7 @@ if _is_cuda:
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
sm_version = get_device_sm()
if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
_enable_jit_deepgemm = True
......
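The JIT DeepGEMM gate above now reads its flag through `get_bool_env_var` rather than `int(os.getenv(...))`, so values such as "true"/"false" also work. The helper sketched below is only a plausible re-implementation for illustration; the real one lives in `sglang.srt.utils` and may accept a different set of spellings.

import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Hypothetical stand-in for sglang.srt.utils.get_bool_env_var.
    value = os.getenv(name, default).strip().lower()
    return value in ("1", "true", "yes", "on")


# Usage mirroring the hunk above (sm_version would come from get_device_sm()):
sm_version = 90
_enable_jit_deepgemm = sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
print(_enable_jit_deepgemm)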
......@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
try:
import vllm
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
)
from vllm.scalar_type import scalar_types
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
class scalar_types:
uint4b8 = "uint4b8"
uint8b128 = "uint8b128"
logger = logging.getLogger(__name__)
......@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["GPTQLinearMethod"]:
if not VLLM_AVAILABLE:
raise ImportError("vllm is not installed")
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
) -> Optional[GPTQLinearMethod]:
# Delay the import to avoid circular dependency
from sglang.srt.layers.quantization import get_linear_quant_method
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
......@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
if VLLM_AVAILABLE:
from vllm.scalar_type import scalar_types
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
else:
raise ImportError("vllm is not installed")
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
def __init__(
self,
......@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
"Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
)
# (num_bits, is_sym) -> quant_type
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
def __repr__(self) -> str:
......@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if not VLLM_AVAILABLE:
raise ImportError("vllm is not installed")
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
) -> Optional[QuantizeMethodBase]:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import get_linear_quant_method
......@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
if not VLLM_AVAILABLE:
return False
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
)
if not _is_cuda:
return False
......@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["MarlinLinearMethod"]:
if not VLLM_AVAILABLE:
raise ImportError("vllm is not installed")
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
# Delay import to avoid circular dependency
) -> Optional[MarlinLinearMethod]:
# Delay the import to avoid circular dependency
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
if isinstance(layer, LinearBase) or (
......
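The gptq.py hunks above work because of the import-time fallback near the top of the file: when vllm is missing, `scalar_types` is replaced by a tiny namespace of string constants, so class-level tables such as `TYPE_MAP` still evaluate and only the code paths that truly need vllm kernels fail. A self-contained sketch of that fallback, with a toy config class standing in for `GPTQMarlinConfig`:

from typing import Any, Dict, Tuple

try:
    from vllm.scalar_type import scalar_types  # real scalar types when vllm is installed
except ImportError:
    class scalar_types:  # placeholder constants so class bodies below still evaluate
        uint4b8 = "uint4b8"
        uint8b128 = "uint8b128"


class ToyMarlinConfig:
    # (num_bits, is_sym) -> quant_type; valid with either the real or placeholder types
    TYPE_MAP: Dict[Tuple[int, bool], Any] = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }


print(ToyMarlinConfig.TYPE_MAP[(4, True)])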
......@@ -53,8 +53,6 @@ class TpModelWorker:
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
):
self.worker = self
# Parse args
self.tp_rank = tp_rank
......@@ -134,6 +132,9 @@ class TpModelWorker:
)[0]
set_random_seed(self.random_seed)
# A reference so that this class has the same members as TpModelWorkerClient
self.worker = self
def get_worker_info(self):
return (
self.max_total_num_tokens,
......
......@@ -73,7 +73,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
from sglang.srt.utils import add_prefix, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
......
#!/bin/bash
set -euxo pipefail
# Install the dependencies in CI.
set -euxo pipefail
# Use repo from environment variable, passed from GitHub Actions
# Use repo from environment variables, passed from GitHub Actions
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
......@@ -17,17 +15,12 @@ pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2
rm -rf /root/.cache/flashinfer
# Force reinstall flashinfer and torch_memory_saver
pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
pip install sgl-kernel==0.0.5.post3 --force-reinstall
pip install torch_memory_saver --force-reinstall
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets
pip install torch_memory_saver
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm
# For compiling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12
# For DeepSeek-VL2
pip install timm
pip install sgl-kernel==0.0.5.post3 --force-reinstall
pip uninstall vllm -y || true
......@@ -45,7 +45,7 @@ class TestEAGLEEngine(CustomTestCase):
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 4,
}
NUM_CONFIGS = 3
NUM_CONFIGS = 2
def setUp(self):
self.prompt = "Today is a sunny day and I like"
......@@ -61,8 +61,6 @@ class TestEAGLEEngine(CustomTestCase):
configs = [
# Basic config
self.BASE_CONFIG,
# Disable cuda graph
{**self.BASE_CONFIG, "disable_cuda_graph": True},
# Chunked prefill
{**self.BASE_CONFIG, "chunked_prefill_size": 4},
]
......
......@@ -28,7 +28,7 @@ class TestTritonAttnBackend(CustomTestCase):
"triton",
"--enable-torch-compile",
"--cuda-graph-max-bs",
16,
4,
],
)
......