sglang · Commit c66b2c9c (unverified)
Authored by Zhiyu on Feb 21, 2025; committed via GitHub on Feb 22, 2025
Parent: 20b765a2

Add support for nvidia modelopt fp8 kv cache (#3223)
Showing 4 changed files with 65 additions and 2 deletions (+65 -2):

    python/sglang/srt/layers/quantization/modelopt_quant.py    +18 -1
    python/sglang/srt/model_loader/weight_utils.py              +12 -1
    python/sglang/srt/models/llama.py                            +6 -0
    test/srt/test_modelopt_fp8kvcache.py (new file)             +29 -0
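For orientation before the per-file diffs, here is a usage sketch of the feature this commit enables. It is not part of the diff: it assumes a Llama checkpoint exported by NVIDIA ModelOpt in FP8, and that sglang's offline Engine forwards `quantization` and `kv_cache_dtype` to the usual server arguments. The model path and the exact option values are illustrative assumptions, not something this commit defines.

# Hypothetical usage sketch; not part of this commit.
# Assumptions: a ModelOpt FP8-exported checkpoint on disk, and that sgl.Engine
# accepts `quantization="modelopt"` and `kv_cache_dtype="fp8_e4m3"` like the
# corresponding server arguments.
import sglang as sgl

llm = sgl.Engine(
    model_path="/path/to/llama-modelopt-fp8",  # illustrative path
    quantization="modelopt",                   # ModelOpt FP8 weight quantization
    kv_cache_dtype="fp8_e4m3",                 # store the KV cache in FP8
)
out = llm.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 8})
print(out)
llm.shutdown()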
python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (

@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+        if isinstance(layer, LinearBase):
+            return ModelOptFp8LinearMethod(self)
+        if isinstance(layer, AttentionBackend):
+            return ModelOptFp8KVCacheMethod(self)
+        return None

     def get_scaled_act_names(self) -> List[str]:
         return []

@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__(quant_config)
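To make the new dispatch concrete, the standalone sketch below mirrors the pattern `get_quant_method` now follows. The classes are simplified stand-ins that only reuse the names from the diff; they are not the real sglang/vllm implementations.

# Standalone illustration of the dispatch introduced above; the stubs below
# only mirror the class names used in the diff.
class LinearBase: ...
class AttentionBackend: ...

class ModelOptFp8LinearMethod:
    def __init__(self, config): self.config = config

class ModelOptFp8KVCacheMethod:
    def __init__(self, config): self.config = config

class ModelOptFp8Config:
    def get_quant_method(self, layer, prefix=""):
        # Linear layers keep resolving to the FP8 linear method ...
        if isinstance(layer, LinearBase):
            return ModelOptFp8LinearMethod(self)
        # ... while attention backends now resolve to the FP8 KV-cache method.
        if isinstance(layer, AttentionBackend):
            return ModelOptFp8KVCacheMethod(self)
        return None

config = ModelOptFp8Config()
assert isinstance(config.get_quant_method(LinearBase()), ModelOptFp8LinearMethod)
assert isinstance(config.get_quant_method(AttentionBackend()), ModelOptFp8KVCacheMethod)
assert config.get_quant_method(object()) is None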
python/sglang/srt/model_loader/weight_utils.py
@@ -644,9 +644,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name

     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            # Check and remap the name based on modelopt scale names
+            if any(
+                modelopt_scale_name in name
+                for modelopt_scale_name in modelopt_scale_names
+            ):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}",
+                )
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 print_warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
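The effect of the new branch is easiest to see on concrete parameter names. The snippet below reproduces only the string rewriting from the hunk above as a standalone function; the real `maybe_remap_kv_scale_name` additionally validates the result against `params_dict` and warns on misses.

# Demonstration of the name rewriting added above (string transform only).
def remap(name: str) -> str:
    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            if any(m in name for m in modelopt_scale_names):
                # ModelOpt checkpoints attach the scale to k_proj / v_proj ...
                return name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}",
                )
            # ... other checkpoints attach it directly to the module.
            return name.replace(scale_name, f".attn{scale_name}")
    return name

# ModelOpt-style names are folded onto the attention module:
assert (
    remap("model.layers.0.self_attn.k_proj.k_scale")
    == "model.layers.0.self_attn.attn.k_scale"
)
# The pre-existing convention keeps its original remapping:
assert (
    remap("model.layers.0.self_attn.v_scale")
    == "model.layers.0.self_attn.attn.v_scale"
)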
python/sglang/srt/models/llama.py
@@ -47,6 +47,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
 )
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback

@@ -457,6 +458,11 @@ class LlamaForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
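With this hook in place, any weight whose name contains "scale" passes through maybe_remap_kv_scale_name before the usual stacked-parameter matching, so ModelOpt-style k_proj.k_scale / v_proj.v_scale entries land on the attention module's parameters, and scale names that cannot be remapped are skipped instead of failing the load.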
test/srt/test_modelopt_fp8kvcache.py
new file (0 → 100644)

import unittest

from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod

from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp8Config,
    ModelOptFp8KVCacheMethod,
)


class TestModelOptFp8KVCacheMethod(unittest.TestCase):
    def test_kv_cache_method_initialization(self):
        """Test that ModelOptFp8KVCacheMethod can be instantiated and
        inherits from BaseKVCacheMethod."""
        # Create a ModelOptFp8Config object
        quant_config = ModelOptFp8Config(is_checkpoint_fp8_serialized=True)

        # Instantiate the KV cache method
        kv_cache_method = ModelOptFp8KVCacheMethod(quant_config)

        # Check inheritance
        self.assertIsInstance(kv_cache_method, BaseKVCacheMethod)

        # Check that the quant_config is stored
        self.assertEqual(kv_cache_method.quant_config, quant_config)


if __name__ == "__main__":
    unittest.main()
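Because the file defines the standard unittest entry point, the test can be run directly with `python test/srt/test_modelopt_fp8kvcache.py` or through `python -m unittest`.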