Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
bf8d07a6
Unverified
Commit
bf8d07a6
authored
Jan 16, 2025
by
Yineng Zhang
Committed by
GitHub
Jan 16, 2025
Browse files
feat: patch linear base (#2915)
parent
ab317936
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
68 additions
and
13 deletions
+68
-13
python/sglang/srt/layers/linear.py
python/sglang/srt/layers/linear.py
+39
-3
python/sglang/srt/layers/quantization/__init__.py
python/sglang/srt/layers/quantization/__init__.py
+21
-4
python/sglang/srt/layers/quantization/fp8.py
python/sglang/srt/layers/quantization/fp8.py
+5
-2
python/sglang/srt/layers/quantization/modelopt_quant.py
python/sglang/srt/layers/quantization/modelopt_quant.py
+1
-2
python/sglang/srt/layers/quantization/w8a8_int8.py
python/sglang/srt/layers/quantization/w8a8_int8.py
+1
-1
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+1
-1
No files found.
python/sglang/srt/layers/linear.py
View file @
bf8d07a6
...
@@ -16,9 +16,6 @@ from vllm.distributed import (
...
@@ -16,9 +16,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
from
vllm.model_executor.layers.linear
import
LinearBase
from
sglang.srt.layers.parameter
import
(
from
sglang.srt.layers.parameter
import
(
BasevLLMParameter
,
BasevLLMParameter
,
PackedColumnParameter
,
PackedColumnParameter
,
...
@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
class
LinearBase
(
torch
.
nn
.
Module
):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
# Keep input parameters
self
.
input_size
=
input_size
self
.
output_size
=
output_size
self
.
skip_bias_add
=
skip_bias_add
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
if
quant_config
is
None
:
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
ReplicatedLinear
(
LinearBase
):
class
ReplicatedLinear
(
LinearBase
):
"""Replicated linear layer.
"""Replicated linear layer.
...
...
python/sglang/srt/layers/quantization/__init__.py
View file @
bf8d07a6
...
@@ -58,12 +58,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
...
@@ -58,12 +58,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
def
fp8_get_quant_method
(
self
,
layer
,
prefix
):
def
fp8_get_quant_method
(
self
,
layer
,
prefix
):
"""Enhanced get_quant_method for FP8 config."""
"""Enhanced get_quant_method for FP8 config."""
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
,
is_layer_skipped
,
)
)
from
sglang.srt.layers.linear
import
UnquantizedLinearMethod
from
sglang.srt.layers.linear
import
LinearBase
,
UnquantizedLinearMethod
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.quantization.fp8
import
Fp8LinearMethod
,
Fp8MoEMethod
from
sglang.srt.layers.quantization.fp8
import
Fp8LinearMethod
,
Fp8MoEMethod
...
@@ -77,12 +76,12 @@ def fp8_get_quant_method(self, layer, prefix):
...
@@ -77,12 +76,12 @@ def fp8_get_quant_method(self, layer, prefix):
def
gptq_get_quant_method
(
self
,
layer
,
prefix
):
def
gptq_get_quant_method
(
self
,
layer
,
prefix
):
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinLinearMethod
,
GPTQMarlinLinearMethod
,
GPTQMarlinMoEMethod
,
GPTQMarlinMoEMethod
,
)
)
from
sglang.srt.layers.linear
import
LinearBase
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
...
@@ -93,12 +92,12 @@ def gptq_get_quant_method(self, layer, prefix):
...
@@ -93,12 +92,12 @@ def gptq_get_quant_method(self, layer, prefix):
def
awq_get_quant_method
(
self
,
layer
,
prefix
):
def
awq_get_quant_method
(
self
,
layer
,
prefix
):
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
AWQMarlinLinearMethod
,
AWQMarlinLinearMethod
,
AWQMoEMethod
,
AWQMoEMethod
,
)
)
from
sglang.srt.layers.linear
import
LinearBase
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
...
@@ -108,6 +107,23 @@ def awq_get_quant_method(self, layer, prefix):
...
@@ -108,6 +107,23 @@ def awq_get_quant_method(self, layer, prefix):
return
None
return
None
def
patch_vllm_linear_base_isinstance
():
import
builtins
from
vllm.model_executor.layers.linear
import
LinearBase
from
sglang.srt.layers.linear
import
LinearBase
as
PatchedLinearBase
original_isinstance
=
builtins
.
isinstance
def
patched_isinstance
(
obj
,
classinfo
):
if
classinfo
is
LinearBase
:
return
original_isinstance
(
obj
,
PatchedLinearBase
)
return
original_isinstance
(
obj
,
classinfo
)
builtins
.
isinstance
=
patched_isinstance
def
apply_monkey_patches
():
def
apply_monkey_patches
():
"""Apply all monkey patches in one place."""
"""Apply all monkey patches in one place."""
setattr
(
Fp8Config
,
"get_quant_method"
,
fp8_get_quant_method
)
setattr
(
Fp8Config
,
"get_quant_method"
,
fp8_get_quant_method
)
...
@@ -115,6 +131,7 @@ def apply_monkey_patches():
...
@@ -115,6 +131,7 @@ def apply_monkey_patches():
setattr
(
AWQMarlinConfig
,
"get_quant_method"
,
awq_get_quant_method
)
setattr
(
AWQMarlinConfig
,
"get_quant_method"
,
awq_get_quant_method
)
patch_vllm_linear_base_isinstance
()
# Apply patches when module is imported
# Apply patches when module is imported
apply_monkey_patches
()
apply_monkey_patches
()
...
...
python/sglang/srt/layers/quantization/fp8.py
View file @
bf8d07a6
...
@@ -9,7 +9,6 @@ from torch.nn import Module
...
@@ -9,7 +9,6 @@ from torch.nn import Module
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
apply_fp8_marlin_linear
,
...
@@ -25,7 +24,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -25,7 +24,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
requantize_with_max_scale
,
requantize_with_max_scale
,
)
)
from
sglang.srt.layers.linear
import
LinearMethodBase
,
UnquantizedLinearMethod
from
sglang.srt.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
)
from
sglang.srt.layers.parameter
import
ModelWeightParameter
,
PerTensorScaleParameter
from
sglang.srt.layers.parameter
import
ModelWeightParameter
,
PerTensorScaleParameter
from
sglang.srt.layers.quantization.base_config
import
(
from
sglang.srt.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizationConfig
,
...
...
python/sglang/srt/layers/quantization/modelopt_quant.py
View file @
bf8d07a6
...
@@ -5,14 +5,13 @@ from typing import Any, Dict, List, Optional
...
@@ -5,14 +5,13 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
apply_fp8_linear
,
cutlass_fp8_supported
,
cutlass_fp8_supported
,
requantize_with_max_scale
,
requantize_with_max_scale
,
)
)
from
sglang.srt.layers.linear
import
LinearMethodBase
from
sglang.srt.layers.linear
import
LinearBase
,
LinearMethodBase
from
sglang.srt.layers.parameter
import
ModelWeightParameter
,
PerTensorScaleParameter
from
sglang.srt.layers.parameter
import
ModelWeightParameter
,
PerTensorScaleParameter
from
sglang.srt.layers.quantization.base_config
import
(
from
sglang.srt.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizationConfig
,
...
...
python/sglang/srt/layers/quantization/w8a8_int8.py
View file @
bf8d07a6
...
@@ -54,7 +54,7 @@ class W8A8Int8Config(QuantizationConfig):
...
@@ -54,7 +54,7 @@ class W8A8Int8Config(QuantizationConfig):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.model_executor
.layers.linear
import
LinearBase
from
sglang.srt
.layers.linear
import
LinearBase
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
W8A8Int8LinearMethod
(
self
)
return
W8A8Int8LinearMethod
(
self
)
...
...
python/sglang/srt/utils.py
View file @
bf8d07a6
...
@@ -574,13 +574,13 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
...
@@ -574,13 +574,13 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
def
monkey_patch_vllm_gguf_config
():
def
monkey_patch_vllm_gguf_config
():
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.gguf
import
(
from
vllm.model_executor.layers.quantization.gguf
import
(
GGUFConfig
,
GGUFConfig
,
GGUFEmbeddingMethod
,
GGUFEmbeddingMethod
,
GGUFLinearMethod
,
GGUFLinearMethod
,
)
)
from
sglang.srt.layers.linear
import
LinearBase
from
sglang.srt.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
sglang.srt.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
def
get_quant_method_with_embedding_replaced
(
def
get_quant_method_with_embedding_replaced
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment