Commit 9b81f9bd (unverified)
sglang quant module remove vllm dependency (#4507)
Authored Mar 18, 2025 by Xiaoyu Zhang · Committed via GitHub Mar 17, 2025
Parent: f81a27f6
Showing 8 changed files with 913 additions and 244 deletions (+913 −244):

  python/sglang/srt/layers/quantization/__init__.py          +180  −123
  python/sglang/srt/layers/quantization/blockwise_int8.py      +1    −1
  python/sglang/srt/layers/quantization/fp8.py                +64   −27
  python/sglang/srt/layers/quantization/fp8_utils.py          +95   −83
  python/sglang/srt/layers/quantization/gptq.py               +24    −3
  python/sglang/srt/layers/quantization/kv_cache.py (new)     +98    −0
  python/sglang/srt/layers/quantization/modelopt_quant.py      +9    −7
  python/sglang/srt/layers/quantization/utils.py (new)       +442    −0
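The common thread across these files is that every direct vllm import is now wrapped in a try/except guard, a module-level flag records whether vllm imported, and placeholder classes stand in for the vllm config types when it did not. A minimal, self-contained sketch of that pattern (illustrative structure only, not the exact sglang code):

# Sketch of the optional-dependency guard used throughout this commit.
try:
    from vllm.model_executor.layers.quantization.awq import AWQConfig  # any vllm symbol

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    class DummyConfig:
        """Placeholder so later code can still reference the name when vllm is absent."""

    AWQConfig = DummyConfig

# Call sites check the flag instead of letting the ImportError escape.
if VLLM_AVAILABLE:
    print("awq backed by vllm:", AWQConfig.__name__)
else:
    print("vllm not installed; awq methods are unavailable")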
python/sglang/srt/layers/quantization/__init__.py  (view file @ 9b81f9bd)

@@ -6,21 +6,41 @@ from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
 
 import torch
 
-from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
-from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig,
-)
-from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+try:
+    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+    from vllm.model_executor.layers.quantization.awq import AWQConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsConfig,
+    )
+    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+        GPTQMarlin24Config,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+    from vllm.model_executor.layers.quantization.qqq import QQQConfig
+    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
+    # Define empty classes as placeholders when vllm is not available
+    class DummyConfig:
+        pass
+
+    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
+        CompressedTensorsConfig
+    ) = DummyConfig
+    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
+        GPTQMarlin24Config
+    ) = DummyConfig
+    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config

@@ -30,29 +50,37 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
-QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
-    "aqlm": AQLMConfig,
-    "awq": AWQConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
+# Base quantization methods that don't depend on vllm
+BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
-    "fbgemm_fp8": FBGEMMFp8Config,
-    "marlin": MarlinConfig,
     "modelopt": ModelOptFp8Config,
-    "gguf": GGUFConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
     "gptq_marlin": GPTQMarlinConfig,
-    "awq_marlin": AWQMarlinConfig,
     "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "experts_int8": ExpertsInt8Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
 }
+
+# Add vllm-dependent methods if available
+QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
+if VLLM_AVAILABLE:
+    VLLM_QUANTIZATION_METHODS = {
+        "aqlm": AQLMConfig,
+        "awq": AWQConfig,
+        "deepspeedfp": DeepSpeedFPConfig,
+        "tpu_int8": Int8TpuConfig,
+        "fbgemm_fp8": FBGEMMFp8Config,
+        "marlin": MarlinConfig,
+        "gguf": GGUFConfig,
+        "gptq_marlin_24": GPTQMarlin24Config,
+        "awq_marlin": AWQMarlinConfig,
+        "compressed-tensors": CompressedTensorsConfig,
+        "bitsandbytes": BitsAndBytesConfig,
+        "qqq": QQQConfig,
+        "experts_int8": ExpertsInt8Config,
+    }
+    QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:

@@ -157,25 +185,31 @@ def get_linear_quant_method(
 def gptq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
-    )
+    if not VLLM_AVAILABLE:
+        return None
 
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    try:
+        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinLinearMethod,
+            GPTQMarlinMoEMethod,
+        )
 
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
+        if isinstance(layer, FusedMoE):
+            return GPTQMarlinMoEMethod(self)
+
+        if isinstance(self, GPTQConfig):
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )
+        elif isinstance(self, GPTQMarlinConfig):
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+            )
+    except ImportError:
+        pass
     return None

@@ -187,33 +221,40 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
     can recognize sglang layers
     """
+    if not VLLM_AVAILABLE:
+        return
+
     if reverse:
         builtins.isinstance = original_isinstance
         return
 
-    from vllm.model_executor.layers.fused_moe import FusedMoE
-    from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding,
-    )
+    try:
+        from vllm.model_executor.layers.fused_moe import FusedMoE
+        from vllm.model_executor.layers.linear import LinearBase
+        from vllm.model_executor.layers.vocab_parallel_embedding import (
+            VocabParallelEmbedding,
+        )
 
-    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-    )
+        from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+        from sglang.srt.layers.moe.fused_moe_triton.layer import (
+            FusedMoE as PatchedFusedMoE,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import (
+            VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+        )
 
-    def patched_isinstance(obj, classinfo):
-        if classinfo is LinearBase:
-            return original_isinstance(obj, PatchedLinearBase)
-        if classinfo is FusedMoE:
-            return original_isinstance(obj, PatchedFusedMoE)
-        if classinfo is VocabParallelEmbedding:
-            return original_isinstance(obj, PatchedVocabParallelEmbedding)
-        return original_isinstance(obj, classinfo)
+        def patched_isinstance(obj, classinfo):
+            if classinfo is LinearBase:
+                return original_isinstance(obj, PatchedLinearBase)
+            if classinfo is FusedMoE:
+                return original_isinstance(obj, PatchedFusedMoE)
+            if classinfo is VocabParallelEmbedding:
+                return original_isinstance(obj, PatchedVocabParallelEmbedding)
+            return original_isinstance(obj, classinfo)
 
-    builtins.isinstance = patched_isinstance
+        builtins.isinstance = patched_isinstance
+    except ImportError:
+        return
 
 
 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):

@@ -221,72 +262,88 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-    original_apply = class_obj.apply
-    sig = inspect.signature(original_apply)
-    param_names = list(sig.parameters.keys())
-    has_correction_bias = "e_score_correction_bias" in param_names
+    if not VLLM_AVAILABLE:
+        return
 
-    def new_apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
-        activation: str = "silu",
-        inplace: bool = True,
-        no_combine: bool = False,
-    ):
-        assert activation == "silu"
-        assert inplace and not no_combine
+    try:
+        original_apply = class_obj.apply
+        sig = inspect.signature(original_apply)
+        param_names = list(sig.parameters.keys())
+        has_correction_bias = "e_score_correction_bias" in param_names
 
-        kwargs = {
-            "self": self,
-            "layer": layer,
-            "x": x,
-            "router_logits": router_logits,
-            "top_k": top_k,
-            "renormalize": renormalize,
-            "use_grouped_topk": use_grouped_topk,
-            "topk_group": topk_group,
-            "num_expert_group": num_expert_group,
-            "custom_routing_function": custom_routing_function,
-        }
-        if correction_bias is not None:
-            if not has_correction_bias:
-                raise ValueError(
-                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
-                )
-            kwargs["e_score_correction_bias"] = correction_bias
-        return original_apply(**kwargs)
+        def new_apply(
+            self,
+            layer: torch.nn.Module,
+            x: torch.Tensor,
+            router_logits: torch.Tensor,
+            top_k: int,
+            renormalize: bool,
+            use_grouped_topk: bool,
+            topk_group: Optional[int] = None,
+            num_expert_group: Optional[int] = None,
+            custom_routing_function: Optional[Callable] = None,
+            correction_bias: Optional[torch.Tensor] = None,
+            activation: str = "silu",
+            inplace: bool = True,
+            no_combine: bool = False,
+        ):
+            assert activation == "silu"
+            assert inplace and not no_combine
 
-    setattr(class_obj, "apply", new_apply)
+            kwargs = {
+                "self": self,
+                "layer": layer,
+                "x": x,
+                "router_logits": router_logits,
+                "top_k": top_k,
+                "renormalize": renormalize,
+                "use_grouped_topk": use_grouped_topk,
+                "topk_group": topk_group,
+                "num_expert_group": num_expert_group,
+                "custom_routing_function": custom_routing_function,
+            }
+            if correction_bias is not None:
+                if not has_correction_bias:
+                    raise ValueError(
+                        "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                    )
+                kwargs["e_score_correction_bias"] = correction_bias
+            return original_apply(**kwargs)
 
+        setattr(class_obj, "apply", new_apply)
+    except (ImportError, AttributeError):
+        return
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
-    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod
+    if not VLLM_AVAILABLE:
+        return
 
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
+    try:
+        from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+            CompressedTensorsW8A8Fp8MoEMethod,
+            CompressedTensorsWNA16MoEMethod,
+        )
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinMoEMethod,
+        )
 
-    monkey_patch_moe_apply(AWQMoEMethod)
-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
-    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
+        setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+        setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
+        monkey_patch_moe_apply(AWQMoEMethod)
+        monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+        monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+        monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
+    except ImportError:
+        return
 
-monkey_patch_quant_configs()
+
+# Only apply monkey patches if vllm is available
+if VLLM_AVAILABLE:
+    monkey_patch_quant_configs()
 
 __all__ = [
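The registry change above splits the method table into a base part that only needs sglang and a vllm part merged in when the import succeeded. A toy stand-in for that lookup behaviour (the config classes here are dummies, not sglang's):

from typing import Dict, Type

VLLM_AVAILABLE = False  # pretend vllm is absent for this toy example


class QuantizationConfig: ...
class Fp8Config(QuantizationConfig): ...
class AWQConfig(QuantizationConfig): ...


BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {"fp8": Fp8Config}
QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
if VLLM_AVAILABLE:
    QUANTIZATION_METHODS.update({"awq": AWQConfig})  # vllm-backed methods


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Unknown quantization method: {quantization!r}")
    return QUANTIZATION_METHODS[quantization]


print(get_quantization_config("fp8").__name__)  # "fp8" still resolves without vllm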
python/sglang/srt/layers/quantization/blockwise_int8.py  (view file @ 9b81f9bd)

@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional
 import torch
 from torch.nn import Module
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (

@@ -19,6 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
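The only change here is that is_layer_skipped now comes from sglang's new quantization utils module instead of vllm. For reference, that helper decides whether a (possibly fused) projection should be left unquantized; a toy illustration of its fused-layer handling (the mapping and layer names below are made up):

from types import MappingProxyType

fused_mapping = MappingProxyType({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
ignored = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# A fused prefix counts as skipped only if every unfused shard is ignored.
prefix = "model.layers.0.self_attn.qkv_proj"
shards = [prefix.replace("qkv_proj", s) for s in fused_mapping["qkv_proj"]]
print(all(s in ignored for s in shards))  # True -> the fused layer is skipped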
python/sglang/srt/layers/quantization/fp8.py  (view file @ 9b81f9bd)

@@ -7,20 +7,33 @@ import torch
 import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
-from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    apply_fp8_marlin_linear,
-    prepare_fp8_layer_for_marlin,
-)
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
     all_close_1d,
     convert_to_channelwise,
+    is_layer_skipped,
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
+
+try:
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+        apply_fp8_marlin_linear,
+        prepare_fp8_layer_for_marlin,
+    )
+
+    MARLIN_FP8_AVAILABLE = True
+except ImportError:
+    MARLIN_FP8_AVAILABLE = False
+
+    def apply_fp8_marlin_linear(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+    def prepare_fp8_layer_for_marlin(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     LinearBase,

@@ -46,6 +59,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
 )
 from sglang.srt.utils import (
     get_bool_env_var,
+    is_cuda,
     is_hip,
     permute_weight,
     print_warning_once,

@@ -60,6 +74,13 @@ if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe
     from aiter.ops.shuffle import shuffle_weight
 
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
 logger = logging.getLogger(__name__)

@@ -173,7 +194,9 @@ class Fp8LinearMethod(LinearMethodBase):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+        self.use_marlin = (
+            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
+        )
         # Disable marlin for ROCm
         if _is_hip:
             self.use_marlin = False

@@ -371,9 +394,12 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
         if self.use_marlin:
-            prepare_fp8_layer_for_marlin(layer)
-            # Activations not quantized for marlin.
-            del layer.input_scale
+            try:
+                prepare_fp8_layer_for_marlin(layer)
+                # Activations not quantized for marlin.
+                del layer.input_scale
+            except ImportError:
+                self.use_marlin = False
 
     def apply(
         self,

@@ -383,15 +409,18 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if self.use_marlin:
-            return apply_fp8_marlin_linear(
-                input=x,
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                workspace=layer.workspace,
-                size_n=layer.output_size_per_partition,
-                size_k=layer.input_size_per_partition,
-                bias=bias,
-            )
+            try:
+                return apply_fp8_marlin_linear(
+                    input=x,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
+                    workspace=layer.workspace,
+                    size_n=layer.output_size_per_partition,
+                    size_k=layer.input_size_per_partition,
+                    bias=bias,
+                )
+            except ImportError:
+                self.use_marlin = False
 
         if self.block_quant:
             return apply_w8a8_block_fp8_linear(

@@ -680,12 +709,20 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             for expert in range(layer.num_experts):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
+                if _is_cuda:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
+                else:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
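The Marlin FP8 path is now opt-in twice over: the SGLANG_FORCE_FP8_MARLIN environment variable must be set and the vllm Marlin kernels must have imported. A simplified stand-in for that gating (only the env-var name is taken from the diff; get_bool_env_var is re-implemented here just for the sketch):

import os


def get_bool_env_var(name: str) -> bool:
    # Stand-in for sglang's helper of the same name.
    return os.environ.get(name, "").lower() in ("1", "true", "yes")


MARLIN_FP8_AVAILABLE = False  # pretend the vllm marlin kernels failed to import
_is_hip = False

use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
if _is_hip:
    use_marlin = False  # Marlin stays disabled on ROCm

print("use_marlin:", use_marlin)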
python/sglang/srt/layers/quantization/fp8_utils.py  (view file @ 9b81f9bd)

@@ -28,7 +28,12 @@ if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
 
 if use_vllm_cutlass_w8a8_fp8_kernel:
-    from vllm import _custom_ops as ops
+    try:
+        from vllm import _custom_ops as ops
+
+        VLLM_AVAILABLE = True
+    except ImportError:
+        VLLM_AVAILABLE = False
 else:
     from sgl_kernel import fp8_scaled_mm

@@ -219,90 +224,97 @@ def apply_fp8_linear(
         )
 
     if cutlass_fp8_supported:
-        if use_vllm_cutlass_w8a8_fp8_kernel:
-            # Fall back to vllm cutlass w8a8 fp8 kernel
-            output = ops.cutlass_scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-        else:
-            assert (
-                weight_scale.numel() == weight.shape[1]
-            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-            output = fp8_scaled_mm(
-                qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
-            )
-        return output.view(*output_shape)
+        try:
+            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                # Fall back to vllm cutlass w8a8 fp8 kernel
+                output = ops.cutlass_scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+            else:
+                assert (
+                    weight_scale.numel() == weight.shape[1]
+                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                output = fp8_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    out_dtype=input.dtype,
+                    bias=bias,
+                )
+            return output.view(*output_shape)
+        except (ImportError, NameError, AttributeError):
+            pass
 
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
-    else:
     per_tensor_weights = weight_scale.numel() == 1
     per_tensor_activations = x_scale.numel() == 1
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
         output = torch._scaled_mm(
             qinput,
             weight,
             out_dtype=input.dtype,
             scale_a=x_scale,
             scale_b=weight_scale,
             bias=bias,
         )
         # A fix for discrepancy in scaled_mm which returns tuple
         # for torch < 2.5 and a single value in torch >= 2.5
         if type(output) is tuple and len(output) == 2:
             output = output[0]
 
         return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
 
     else:
         # Fallback for channelwise case, where we use unfused DQ
         # due to limitations with scaled_mm
 
         # Symmetric quantized GEMM by definition computes the following:
         #   C = (s_x * X) (s_w * W) + bias
         # This is equivalent to dequantizing the weights and activations
         # before applying a GEMM.
         #
         # In order to compute quantized operands, a quantized kernel
         # will rewrite the above like so:
         #   C = s_w * s_x * (X * W) + bias
         #
         # For the scaled_mm fallback case, we break this down, since it
         # does not support s_w being a vector.
 
         # Making sure the dummy tensor is on the same device as the weight
         global TORCH_DEVICE_IDENTITY
         if TORCH_DEVICE_IDENTITY.device != weight.device:
             TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
 
         # GEMM
         # This computes C = (X * W).
         # Output in fp32 to allow subsequent ops to happen in-place
         output = torch._scaled_mm(
             qinput,
             weight,
             scale_a=TORCH_DEVICE_IDENTITY,
             scale_b=TORCH_DEVICE_IDENTITY,
             out_dtype=torch.float32,
         )
         # A fix for discrepancy in scaled_mm which returns tuple
         # for torch < 2.5 and a single value in torch >= 2.5
         if type(output) is tuple and len(output) == 2:
             output = output[0]
         # Unpad (undo num_token_padding)
         output = torch.narrow(output, 0, 0, input_2d.shape[0])
         x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
 
         # DQ
         # C = sw * sx * (X * W) + bias
         output = output * x_scale * weight_scale.t()
         if bias is not None:
             output = output + bias
         return output.to(dtype=input.dtype).view(*output_shape)
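The channelwise fallback kept above relies on the identity that the scales can be applied after the GEMM, i.e. (s_x·X)(s_w·W) = s_x·s_w·(X·W). A quick float32 check of that identity (shapes and scales are arbitrary, and plain float32 avoids needing FP8 hardware):

import torch

X = torch.randn(4, 8)    # activations (rows = tokens)
W = torch.randn(8, 16)   # weights
s_x = torch.rand(4, 1)   # per-token activation scales
s_w = torch.rand(16, 1)  # per-channel weight scales

lhs = (s_x * X) @ (W * s_w.t())   # dequantize first, then GEMM
rhs = (X @ W) * s_x * s_w.t()     # GEMM first, then scale (the code path above)
print(torch.allclose(lhs, rhs, atol=1e-5))  # True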
python/sglang/srt/layers/quantization/gptq.py  (view file @ 9b81f9bd)

@@ -3,11 +3,21 @@ from fractions import Fraction
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-from vllm.scalar_type import scalar_types
 
 from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.utils import scalar_types
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
 
 logger = logging.getLogger(__name__)

@@ -110,6 +120,9 @@ class GPTQConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["GPTQLinearMethod"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 
         from sglang.srt.layers.quantization import get_linear_quant_method

@@ -263,6 +276,9 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.gptq_marlin import (
             GPTQMarlinLinearMethod,
             GPTQMarlinMoEMethod,

@@ -285,6 +301,9 @@ class GPTQMarlinConfig(QuantizationConfig):
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        if not VLLM_AVAILABLE:
+            return False
+
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")

@@ -294,9 +313,8 @@ class GPTQMarlinConfig(QuantizationConfig):
         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
             check_marlin_supported,
         )
-        from vllm.platforms import current_platform
 
-        if not current_platform.is_cuda():
+        if not _is_cuda:
             return False
 
         if quant_method != "gptq":

@@ -407,6 +425,9 @@ class MarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["MarlinLinearMethod"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
 
         if isinstance(layer, LinearBase) or (
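With vllm optional, the GPTQ configs now fail fast: get_quant_method raises ImportError and is_gptq_marlin_compatible simply returns False when vllm (or CUDA) is unavailable. A simplified stand-in for that early-exit structure; the flags are hard-coded for illustration and the bit-width test at the end is only a placeholder for the real Marlin support check:

VLLM_AVAILABLE = True  # illustrative flags, not read from the real environment
_is_cuda = True


def is_gptq_marlin_compatible(quant_config: dict) -> bool:
    if not VLLM_AVAILABLE:
        return False
    if not _is_cuda:
        return False
    if quant_config.get("quant_method", "").lower() != "gptq":
        return False
    # The real check also validates bits/group_size/sym against Marlin support.
    return quant_config.get("bits") in (4, 8)


print(is_gptq_marlin_compatible({"quant_method": "gptq", "bits": 4}))  # True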
python/sglang/srt/layers/quantization/kv_cache.py  (new file, 0 → 100644, view file @ 9b81f9bd)

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kv_cache.py

import logging

import torch

from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import is_hip

_is_hip = is_hip()

logger = logging.getLogger(__name__)


class BaseKVCacheMethod(QuantizeMethodBase):
    """
    Quant method that adds `_k_scale` and `_v_scale` attributes to the
    Attention layer to support loading those scaling factors from checkpoints.
    The k/v_scale will be used to:
        - quantize k/v_cache entries before saving them to the cache
        - dequantize k/v_cache entries before fetching them from the cache

    :param quant_config: the appropriate QuantizationConfig
    """

    def __init__(self, quant_config: QuantizationConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """
        Create "weight" (aka k_scale and v_scale) for an attention layer.
        """
        # Initialize the KV cache scales to -1.0, which is an invalid value.
        # If the k/v_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
        # regardless whether the kv-scale is available in the checkpoint.
        # No need to process kv scales after loading if we are going to
        # calculate them on the fly.
        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                # We prefer to use separate k_scale and v_scale if present
                k_scale = layer.k_scale.to("cpu").tolist()
                v_scale = layer.v_scale.to("cpu").tolist()
                if _is_hip and self.is_fp8_fnuz():
                    k_scale *= 2
                    v_scale *= 2
            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
                # If no scales were loaded (both scales are invalid negative
                # values), use the default value of 1.0
                k_scale = 1.0
                v_scale = 1.0
            else:
                # If we find a single kv_scale in the checkpoint, we remap
                # kv_scale to k_scale during weight loading, and duplicate
                # k_scale to v_scale here
                assert layer.k_scale > 0.0
                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                k_scale = scale_to_duplicate.to("cpu").tolist()
                v_scale = scale_to_duplicate.to("cpu").tolist()
                if _is_hip and self.is_fp8_fnuz():
                    k_scale *= 2
                    v_scale *= 2

            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
                raise ValueError(
                    "Only support per-tensor scaling factor " "for fp8 KV cache"
                )

            # These are used in the final Attention.forward()
            layer._k_scale.copy_(k_scale)
            layer._v_scale.copy_(v_scale)
            layer._k_scale_float = k_scale
            layer._v_scale_float = v_scale
            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
                logger.warning(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                    "may cause accuracy issues. Please make sure k/v_scale "
                    "scaling factors are available in the fp8 checkpoint."
                )

        del layer.k_scale
        del layer.v_scale
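The core of process_weights_after_loading is the scale-resolution logic: prefer separate k/v scales, fall back to 1.0 when none were loaded, and duplicate a single checkpoint kv_scale otherwise, doubling on fp8-fnuz (MI300) hardware to compensate for that format's reduced range. A toy, tensor-free walk-through of just that branch structure (the input values are made up; -1.0 means "not found in the checkpoint"):

def resolve_kv_scales(k_scale: float, v_scale: float, is_fp8_fnuz: bool = False):
    if k_scale > 0.0 and v_scale > 0.0:
        # Both scales present in the checkpoint: use them as-is.
        if is_fp8_fnuz:  # fp8e4m3fnuz covers roughly half the range of fp8e4m3fn
            k_scale, v_scale = k_scale * 2, v_scale * 2
    elif k_scale < 0.0 and v_scale < 0.0:
        # Neither present: fall back to 1.0.
        k_scale, v_scale = 1.0, 1.0
    else:
        # A single kv_scale was remapped to k_scale during loading; duplicate it.
        shared = max(k_scale, v_scale)
        k_scale = v_scale = shared * (2 if is_fp8_fnuz else 1)
    return k_scale, v_scale


print(resolve_kv_scales(0.023, -1.0))  # (0.023, 0.023)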
python/sglang/srt/layers/quantization/modelopt_quant.py  (view file @ 9b81f9bd)

@@ -5,12 +5,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
-from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    convert_to_channelwise,
-    cutlass_fp8_supported,
-    requantize_with_max_scale,
-)
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase

@@ -19,7 +13,15 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
+from sglang.srt.layers.quantization.fp8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+)
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    convert_to_channelwise,
+    requantize_with_max_scale,
+)
 
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
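This file now pulls convert_to_channelwise and requantize_with_max_scale from sglang's own utils module (added below) and cutlass_fp8_supported from fp8_utils, instead of vllm. As a reminder of what convert_to_channelwise does, here is the same expansion written out on toy numbers:

import torch

weight_scale = torch.tensor([0.5, 0.25])  # one scale per fused shard (made up)
logical_widths = [3, 2]                   # output rows owned by each shard

# Broadcast each per-shard scale to one scale per output row.
channelwise = torch.empty((sum(logical_widths), 1), dtype=torch.float32)
start = 0
for idx, width in enumerate(logical_widths):
    channelwise[start : start + width, :] = weight_scale[idx]
    start += width

print(channelwise.squeeze(1).tolist())  # [0.5, 0.5, 0.5, 0.25, 0.25]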
python/sglang/srt/layers/quantization/utils.py  (new file, 0 → 100644, view file @ 9b81f9bd)

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/scalar_type.py

import functools
import struct
from dataclasses import dataclass
from enum import Enum
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union

import torch


def is_layer_skipped(
    prefix: str,
    ignored_layers: List[str],
    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in fused_mapping[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "to have the same precision."
                )
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped


def per_tensor_dequantize(
    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


def convert_to_channelwise(
    weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Create channelwise buffer
    weight_scale_channel = torch.empty(
        (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
    )

    # Expand each scale to match the size of each logical matrix.
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_scale_channel[start:end, :] = weight_scale[idx]
        start = end

    return weight_scale_channel


def requantize_with_max_scale(
    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Max scale to be used for requanitzation.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on disk checkpoint if any of the
    # weight scales are still set to the default since we initialize
    # N weight scales for N shards but we only load 1 weight scale
    # from disk in this case. Skip requantization in this case (since)
    # we already are quantized with the single scale.
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    unfused_module_in_checkpoint = (
        weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
    )

    # If unfused checkpoint, need requanize with the single scale.
    if unfused_module_in_checkpoint:
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
            weight[start:end, :], _ = ops.scaled_fp8_quant(weight_dq, max_w_scale)
            start = end

    return max_w_scale, weight


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, used `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represent in this scalar type, returns NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        double_raw = self._floating_point_max_int()
        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (
                self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
            ), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        if self.is_floating_point():
            assert (
                self.is_signed()
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
        else:
            assert (
                not self.is_signed() or self.size_bits <= 64
            ), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"

        return val

    @property
    def size_bits(self) -> int:
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        return self.nan_repr != NanRepr.NONE.value

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
        `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = (
                "float"
                + str(self.size_bits)
                + "_e"
                + str(self.exponent)
                + "m"
                + str(self.mantissa)
            )

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(
        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
    ) -> "ScalarType":
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        assert nan_repr != NanRepr.IEEE_754, (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions"
        )
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret


# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
#  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
#  flags:
#  - no-flags: means it follows IEEE 754 conventions
#  - f: means finite values only (no infinities)
#  - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
#  `[u]int<size_bits>[b<bias>]`
#  - if bias is not present it means its zero


class scalar_types:
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
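The ScalarType port computes type limits by assembling an IEEE-754 double bit pattern by hand. The snippet below reproduces that arithmetic for float8_e4m3fn (exponent=4, mantissa=3, extended-range NaN encoding) outside the class, as a standalone sanity check of the shift-and-or logic:

import struct

exponent, mantissa = 4, 3
max_mantissa = (1 << mantissa) - 1 - 1     # top mantissa pattern is reserved for NaN
max_exponent = (1 << exponent) - 2 + 1     # all-ones exponent is a normal value here
exponent_bias = (1 << (exponent - 1)) - 1  # 7
exponent_bias_double = (1 << 10) - 1       # 1023
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double

bits = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52)
print(struct.unpack("!d", struct.pack("!Q", bits))[0])
# 448.0, matching torch.finfo(torch.float8_e4m3fn).max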