sglang · Commit 74e0ac1d (unverified)
Authored Mar 28, 2025 by Lianmin Zheng; committed by GitHub on Mar 28, 2025
Parent: ef9a378a

Clean up `import vllm` in quantization/__init__.py (#4834)

Showing 14 changed files with 195 additions and 258 deletions (+195, -258)
.github/workflows/pr-test.yml (+4, -8)
.github/workflows/vllm-dependency-test.yml (+4, -8)
python/pyproject.toml (+7, -1)
python/sglang/srt/configs/model_config.py (+3, -16)
python/sglang/srt/layers/moe/topk.py (+1, -1)
python/sglang/srt/layers/quantization/__init__.py (+132, -163)
python/sglang/srt/layers/quantization/awq.py (+1, -1)
python/sglang/srt/layers/quantization/fp8_kernel.py (+2, -1)
python/sglang/srt/layers/quantization/gptq.py (+30, -40)
python/sglang/srt/managers/tp_worker.py (+3, -2)
python/sglang/srt/models/deepseek_v2.py (+1, -1)
scripts/ci_install_dependency.sh (+5, -12)
test/srt/test_eagle_infer.py (+1, -3)
test/srt/test_triton_attention_backend.py (+1, -1)
.github/workflows/pr-test.yml

@@ -4,19 +4,15 @@ on:
   push:
     branches: [ main ]
     paths:
-      - "python/pyproject.toml"
-      - "python/sglang/**"
-      - "test/**"
-      - "docs/**"
+      - "python/**"
       - "scripts/**"
+      - "test/**"
   pull_request:
     branches: [ main ]
     paths:
-      - "python/pyproject.toml"
-      - "python/sglang/**"
-      - "test/**"
-      - "docs/**"
+      - "python/**"
       - "scripts/**"
+      - "test/**"
   workflow_dispatch:
     inputs:
       version:
.github/workflows/vllm-dependency-test.yml

@@ -4,19 +4,15 @@ on:
   push:
     branches: [ main ]
     paths:
-      - "python/pyproject.toml"
-      - "python/sglang/**"
-      - "test/**"
-      - "docs/**"
+      - "python/**"
       - "scripts/**"
+      - "test/**"
   pull_request:
     branches: [ main ]
     paths:
-      - "python/pyproject.toml"
-      - "python/sglang/**"
-      - "test/**"
-      - "docs/**"
+      - "python/**"
       - "scripts/**"
+      - "test/**"

 concurrency:
   group: vllm-dependency-test-${{ github.ref }}
python/pyproject.toml

@@ -17,6 +17,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
 [project.optional-dependencies]
 runtime_common = [
+    "compressed-tensors",
     "datasets",
     "decord",
     "fastapi",

@@ -56,7 +57,12 @@ srt = [
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20250114, not from public vllm whl
-srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
+srt_hip = [
+    "sglang[runtime_common]",
+    "torch",
+    "vllm==0.6.7.dev2",
+    "outlines==0.1.11"
+]

 # xpu is not enabled in public vllm and torch whl,
 # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm
python/sglang/srt/configs/model_config.py

@@ -22,11 +22,7 @@ import torch
 from transformers import PretrainedConfig

 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.layers.quantization import (
-    BASE_QUANTIZATION_METHODS,
-    QUANTIZATION_METHODS,
-    VLLM_AVAILABLE,
-)
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.utils import get_bool_env_var, is_hip

 logger = logging.getLogger(__name__)

@@ -239,12 +235,7 @@ class ModelConfig:
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
-        # Select supported quantization methods based on vllm availability
-        if VLLM_AVAILABLE:
-            supported_quantization = [*QUANTIZATION_METHODS]
-        else:
-            supported_quantization = [*BASE_QUANTIZATION_METHODS]
+        supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = [
             "awq",
             "gptq",

@@ -282,11 +273,7 @@ class ModelConfig:
             quant_method = quant_cfg.get("quant_method", "").lower()

             # Detect which checkpoint is it
-            # Only iterate through currently available quantization methods
-            available_methods = (
-                QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
-            )
-            for _, method in available_methods.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization
                 )
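The simplified loop above leans on every registered config class exposing an `override_quantization_method` hook. As a rough, self-contained sketch of that registry-driven detection, not part of the commit and with illustrative class names and heuristics rather than the real sglang ones:

# Minimal sketch of registry-driven quantization detection, assuming a
# QUANTIZATION_METHODS dict like the one built in quantization/__init__.py.
from typing import Dict, Optional, Type


class BaseQuantConfig:
    @classmethod
    def override_quantization_method(
        cls, quant_cfg: dict, user_choice: Optional[str]
    ) -> Optional[str]:
        # Return a method name to force, or None to leave the choice alone.
        return None


class FakeGPTQMarlinConfig(BaseQuantConfig):
    @classmethod
    def override_quantization_method(cls, quant_cfg, user_choice):
        # Upgrade plain "gptq" checkpoints to a marlin-style kernel when possible.
        if quant_cfg.get("quant_method") == "gptq" and quant_cfg.get("sym"):
            return "gptq_marlin"
        return None


QUANTIZATION_METHODS: Dict[str, Type[BaseQuantConfig]] = {
    "gptq": BaseQuantConfig,
    "gptq_marlin": FakeGPTQMarlinConfig,
}


def detect_quantization(quant_cfg: dict, user_choice: Optional[str]) -> Optional[str]:
    # Mirrors the loop in ModelConfig._verify_quantization: every registered
    # config class gets a chance to override the detected method.
    for _, method in QUANTIZATION_METHODS.items():
        override = method.override_quantization_method(quant_cfg, user_choice)
        if override is not None:
            return override
    return user_choice


print(detect_quantization({"quant_method": "gptq", "sym": True}, "gptq"))  # gptq_marlin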
python/sglang/srt/layers/moe/topk.py

@@ -17,12 +17,12 @@ from typing import Callable, Optional

 import torch
 import torch.nn.functional as F

+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()

-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 expert_distribution_recorder = ExpertDistributionRecorder()
python/sglang/srt/layers/quantization/__init__.py

@@ -9,12 +9,24 @@ import torch
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinConfig,
+        AWQMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )

@@ -22,24 +34,24 @@ try:
     from vllm.model_executor.layers.quantization.qqq import QQQConfig
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

-    from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False

     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
-        pass
+        def override_quantization_method(self, *args, **kwargs):
+            return None

     AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
-        DummyConfig
-    )
-    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
-        GPTQMarlin24Config
-    ) = DummyConfig
-    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
+        DeepSpeedFPConfig
+    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
+        MarlinConfig
+    ) = QQQConfig = Int8TpuConfig = DummyConfig

+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config

@@ -47,9 +59,14 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {

@@ -61,26 +78,25 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
 }

-# Add vllm-dependent methods if available
-QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
-if VLLM_AVAILABLE:
-    VLLM_QUANTIZATION_METHODS = {
-        "aqlm": AQLMConfig,
-        "awq": AWQConfig,
-        "deepspeedfp": DeepSpeedFPConfig,
-        "tpu_int8": Int8TpuConfig,
-        "fbgemm_fp8": FBGEMMFp8Config,
-        "marlin": MarlinConfig,
-        "gguf": GGUFConfig,
-        "gptq_marlin_24": GPTQMarlin24Config,
-        "awq_marlin": AWQMarlinConfig,
-        "bitsandbytes": BitsAndBytesConfig,
-        "qqq": QQQConfig,
-        "experts_int8": ExpertsInt8Config,
-        "gptq_marlin": GPTQMarlinConfig,
-        "gptq": GPTQConfig,
-    }
-    QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
+# VLLM-dependent quantization methods
+VLLM_QUANTIZATION_METHODS = {
+    "aqlm": AQLMConfig,
+    "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
+    "tpu_int8": Int8TpuConfig,
+    "fbgemm_fp8": FBGEMMFp8Config,
+    "marlin": MarlinConfig,
+    "gguf": GGUFConfig,
+    "gptq_marlin_24": GPTQMarlin24Config,
+    "awq_marlin": AWQMarlinConfig,
+    "bitsandbytes": BitsAndBytesConfig,
+    "qqq": QQQConfig,
+    "experts_int8": ExpertsInt8Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "gptq": GPTQConfig,
+}
+
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}

 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:

@@ -89,6 +105,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
+    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
+        raise ValueError(
+            f"{quantization} quantization requires some operators from vllm. "
+            "Pleaes install vllm by `pip install vllm==0.7.2`"
+        )
+
     return QUANTIZATION_METHODS[quantization]

@@ -153,13 +175,6 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized

@@ -186,31 +201,17 @@ def get_linear_quant_method(
 def gptq_get_quant_method(self, layer, prefix):
-    if not VLLM_AVAILABLE:
-        return None
-    try:
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
-        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-        if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
-
-        if isinstance(self, GPTQConfig):
-            return get_linear_quant_method(
-                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-            )
-        elif isinstance(self, GPTQMarlinConfig):
-            return get_linear_quant_method(
-                self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-            )
-    except ImportError:
-        pass
+    if isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+
+    if isinstance(self, GPTQConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+        )
+    elif isinstance(self, GPTQMarlinConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+        )
     return None

@@ -229,33 +230,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         builtins.isinstance = original_isinstance
         return

-    try:
-        from vllm.model_executor.layers.fused_moe import FusedMoE
-        from vllm.model_executor.layers.linear import LinearBase
-        from vllm.model_executor.layers.vocab_parallel_embedding import (
-            VocabParallelEmbedding,
-        )
-
-        from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-        from sglang.srt.layers.moe.fused_moe_triton.layer import (
-            FusedMoE as PatchedFusedMoE,
-        )
-        from sglang.srt.layers.vocab_parallel_embedding import (
-            VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-        )
-
-        def patched_isinstance(obj, classinfo):
-            if classinfo is LinearBase:
-                return original_isinstance(obj, PatchedLinearBase)
-            if classinfo is FusedMoE:
-                return original_isinstance(obj, PatchedFusedMoE)
-            if classinfo is VocabParallelEmbedding:
-                return original_isinstance(obj, PatchedVocabParallelEmbedding)
-            return original_isinstance(obj, classinfo)
-
-        builtins.isinstance = patched_isinstance
-    except ImportError:
-        return
+    from vllm.model_executor.layers.fused_moe import FusedMoE
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding,
+    )
+
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+    )
+
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        if classinfo is FusedMoE:
+            return original_isinstance(obj, PatchedFusedMoE)
+        if classinfo is VocabParallelEmbedding:
+            return original_isinstance(obj, PatchedVocabParallelEmbedding)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance

@@ -263,91 +259,64 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-    if not VLLM_AVAILABLE:
-        return
-
-    try:
     original_apply = class_obj.apply
     sig = inspect.signature(original_apply)
     param_names = list(sig.parameters.keys())
    has_correction_bias = "e_score_correction_bias" in param_names

     def new_apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
         use_grouped_topk: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
     ):
         assert activation == "silu"
         assert inplace and not no_combine

         kwargs = {
             "self": self,
             "layer": layer,
             "x": x,
             "router_logits": router_logits,
             "top_k": top_k,
             "renormalize": renormalize,
             "use_grouped_topk": use_grouped_topk,
             "topk_group": topk_group,
             "num_expert_group": num_expert_group,
             "custom_routing_function": custom_routing_function,
         }
         if correction_bias is not None:
             if not has_correction_bias:
                 raise ValueError(
                     "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
                 )
             kwargs["e_score_correction_bias"] = correction_bias
         return original_apply(**kwargs)

     setattr(class_obj, "apply", new_apply)
-    except (ImportError, AttributeError):
-        return


 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    if not VLLM_AVAILABLE:
-        return
-    try:
-        from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
-        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-            CompressedTensorsW8A8Fp8MoEMethod,
-            CompressedTensorsWNA16MoEMethod,
-        )
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinMoEMethod,
-        )
-
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

     monkey_patch_moe_apply(AWQMoEMethod)
     monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-    except ImportError:
-        return


 # Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()

 __all__ = [
     "get_quantization_config",
     "QUANTIZATION_METHODS",
 ]
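The restructured module boils down to one pattern: import the optional dependency once at module scope, substitute placeholders on failure, always build the full registry, and only complain when a dependency-backed entry is actually requested. Below is a minimal, runnable sketch of that pattern, not part of the commit, with made-up module and class names rather than the real vllm/sglang ones:

# Minimal sketch of the optional-dependency registry pattern used by the new
# quantization/__init__.py. "some_optional_lib" and the config names are
# placeholders, not real packages or sglang classes.
from typing import Dict, Type

try:
    from some_optional_lib import FancyConfig  # hypothetical optional import

    OPTIONAL_AVAILABLE = True
except ImportError:
    OPTIONAL_AVAILABLE = False

    class _Dummy:
        """Placeholder so the registry below can always be constructed."""

        def override_quantization_method(self, *args, **kwargs):
            return None

    FancyConfig = _Dummy


class BuiltinConfig:
    pass


BASE_METHODS: Dict[str, Type] = {"builtin": BuiltinConfig}
OPTIONAL_METHODS: Dict[str, Type] = {"fancy": FancyConfig}
METHODS = {**BASE_METHODS, **OPTIONAL_METHODS}


def get_config(name: str) -> Type:
    if name not in METHODS:
        raise ValueError(f"Invalid method: {name}. Available: {list(METHODS)}")
    if name in OPTIONAL_METHODS and not OPTIONAL_AVAILABLE:
        raise ValueError(f"{name} requires the optional dependency; please install it.")
    return METHODS[name]


print(get_config("builtin"))  # always works, with or without the optional library

The benefit over the old `if VLLM_AVAILABLE:` branching is that callers such as model_config.py can treat the registry as a single dictionary and let the lookup raise a clear error on its own.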
python/sglang/srt/layers/quantization/awq.py

 # SPDX-License-Identifier: Apache-2.0

 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional

 import torch
 from sgl_kernel import awq_dequantize
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -24,6 +24,7 @@ import triton.language as tl
 from sglang.srt.utils import (
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
     get_device_sm,

@@ -43,7 +44,7 @@ if _is_cuda:
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

     sm_version = get_device_sm()
-    if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
+    if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
         _enable_jit_deepgemm = True
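The motivation for switching to `get_bool_env_var` is that `int(os.getenv(...))` only accepts numeric strings. A small sketch of what such a helper typically does, not part of the commit and an assumption about its behavior rather than the exact sglang utility:

import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Sketch of a tolerant boolean env-var parser; the real helper may differ
    # in detail, but the point is to accept "1/true/yes" style values instead
    # of relying on int() conversion.
    value = os.getenv(name, default).strip().lower()
    return value in ("1", "true", "yes", "on")


os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "true"
# int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")) would raise ValueError here,
# while the helper returns a plain bool:
print(get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"))  # True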
python/sglang/srt/layers/quantization/gptq.py

@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
 _is_cuda = is_cuda()

 try:
-    import vllm
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+        check_marlin_supported,
+    )
+    from vllm.scalar_type import scalar_types

     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False

+    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+
+    class scalar_types:
+        uint4b8 = "uint4b8"
+        uint8b128 = "uint8b128"
+
 logger = logging.getLogger(__name__)

@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["GPTQLinearMethod"]:
-        # Delay the import to avoid circular dependency
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-
+    ) -> Optional[GPTQLinearMethod]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
         from sglang.srt.layers.quantization import get_linear_quant_method

         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)

@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""

-    if VLLM_AVAILABLE:
-        # (num_bits, is_sym) -> quant_type
-        from vllm.scalar_type import scalar_types
-
-        TYPE_MAP = {
-            (4, True): scalar_types.uint4b8,
-            (8, True): scalar_types.uint8b128,
-        }
-    else:
-        raise ImportError("vllm is not installed")
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }

     def __init__(
         self,

@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                 "Unsupported quantization config: "
                 f"bits={weight_bits}, sym={is_sym}"
             )
+        # (num_bits, is_sym) -> quant_type
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

     def __repr__(self) -> str:

@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        # Delay the import to avoid circular dependency
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
+    ) -> Optional[QuantizeMethodBase]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.quantization import get_linear_quant_method

@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        if not VLLM_AVAILABLE:
+            return False
+
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
         sym = quant_config.get("sym")
         desc_act = quant_config.get("desc_act")

-        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-            check_marlin_supported,
-        )
-
         if not _is_cuda:
             return False

@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["MarlinLinearMethod"]:
-        # Delay the import to avoid circular dependency
-        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-
+    ) -> Optional[MarlinLinearMethod]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+        # Delay import to avoid circular dependency
         from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

         if isinstance(layer, LinearBase) or (
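The fallback branch added to gptq.py binds the missing names to `Any` and defines a tiny stand-in `scalar_types`, so unquoted type annotations and class-level tables still evaluate without vllm while any real use fails fast. A minimal sketch of that fallback idea, not part of the commit, using a made-up optional module:

# Sketch of the "alias-to-Any plus tiny stand-in class" fallback pattern.
# "optional_kernels" is a made-up module name, not a real dependency.
from typing import Any, Optional

try:
    from optional_kernels import FastLinearMethod, scalar_types  # hypothetical

    KERNELS_AVAILABLE = True
except ImportError:
    KERNELS_AVAILABLE = False
    FastLinearMethod = Any  # keeps "-> Optional[FastLinearMethod]" evaluable

    class scalar_types:  # placeholder values for class-level tables
        uint4b8 = "uint4b8"
        uint8b128 = "uint8b128"


class DemoConfig:
    # The class-level table can reference scalar_types either way.
    TYPE_MAP = {(4, True): scalar_types.uint4b8, (8, True): scalar_types.uint8b128}

    def get_quant_method(self, layer, prefix: str) -> Optional[FastLinearMethod]:
        if not KERNELS_AVAILABLE:
            raise ImportError("optional_kernels is not installed")
        return FastLinearMethod()


print(DemoConfig.TYPE_MAP[(4, True)])  # works with or without the dependency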
python/sglang/srt/managers/tp_worker.py

@@ -53,8 +53,6 @@ class TpModelWorker:
         req_to_token_pool: Optional[ReqToTokenPool] = None,
         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
-        self.worker = self
-
         # Parse args
         self.tp_rank = tp_rank

@@ -134,6 +132,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

+        # A reference make this class has the same member as TpModelWorkerClient
+        self.worker = self
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
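The relocated `self.worker = self` assignment exists so that TpModelWorker exposes the same `.worker` member as TpModelWorkerClient. A tiny sketch of that aliasing idea, not part of the commit, with stand-in classes rather than the sglang ones:

# Callers can use obj.worker uniformly whether obj is the worker itself or a
# client wrapping one. Classes here are illustrative stand-ins.
class Worker:
    def __init__(self) -> None:
        # A self-reference so this class exposes the same member as the wrapper.
        self.worker = self

    def get_memory_pool(self) -> str:
        return "kv-pool"


class WorkerClient:
    def __init__(self, worker: Worker) -> None:
        self.worker = worker


def inspect_pool(obj) -> str:
    # Works for both, because both expose `.worker`.
    return obj.worker.get_memory_pool()


print(inspect_pool(Worker()), inspect_pool(WorkerClient(Worker())))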
python/sglang/srt/models/deepseek_v2.py

@@ -73,7 +73,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
scripts/ci_install_dependency.sh

 #!/bin/bash
+set -euxo pipefail
+
 # Install the dependency in CI.
-set -euxo pipefail

-# Use repo from environment variables, passed from GitHub Actions
+# Use repo from environment variable, passed from GitHub Actions
 FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"

 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

@@ -17,17 +15,12 @@ pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2
 rm -rf /root/.cache/flashinfer
 # Force reinstall flashinfer and torch_memory_saver
 pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
-pip install torch_memory_saver --force-reinstall
-
-pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets
+pip install sgl-kernel==0.0.5.post3 --force-reinstall
+pip install torch_memory_saver
+pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm

 # For compling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12

-# For DeepSeek-VL2
-pip install timm
-
-pip install sgl-kernel==0.0.5.post3 --force-reinstall
-
 pip uninstall vllm -y || true
test/srt/test_eagle_infer.py

@@ -45,7 +45,7 @@ class TestEAGLEEngine(CustomTestCase):
         "mem_fraction_static": 0.7,
         "cuda_graph_max_bs": 4,
     }
-    NUM_CONFIGS = 3
+    NUM_CONFIGS = 2

     def setUp(self):
         self.prompt = "Today is a sunny day and I like"

@@ -61,8 +61,6 @@ class TestEAGLEEngine(CustomTestCase):
         configs = [
             # Basic config
             self.BASE_CONFIG,
-            # Disable cuda graph
-            {**self.BASE_CONFIG, "disable_cuda_graph": True},
             # Chunked prefill
             {**self.BASE_CONFIG, "chunked_prefill_size": 4},
         ]
test/srt/test_triton_attention_backend.py

@@ -28,7 +28,7 @@ class TestTritonAttnBackend(CustomTestCase):
                 "triton",
                 "--enable-torch-compile",
                 "--cuda-graph-max-bs",
-                16,
+                4,
             ],
         )