sglang commit c66b2c9c (unverified)

Add support for nvidia modelopt fp8 kv cache (#3223)

Authored Feb 21, 2025 by Zhiyu; committed via GitHub on Feb 22, 2025.
Parent: 20b765a2
Showing 4 changed files with 65 additions and 2 deletions:

  python/sglang/srt/layers/quantization/modelopt_quant.py  (+18, -1)
  python/sglang/srt/model_loader/weight_utils.py           (+12, -1)
  python/sglang/srt/models/llama.py                        (+6, -0)
  test/srt/test_modelopt_fp8kvcache.py                     (+29, -0)
python/sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )

+from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (

@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+        if isinstance(layer, LinearBase):
+            return ModelOptFp8LinearMethod(self)
+
+        if isinstance(layer, AttentionBackend):
+            return ModelOptFp8KVCacheMethod(self)
+
+        return None

     def get_scaled_act_names(self) -> List[str]:
         return []

@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__(quant_config)
```
python/sglang/srt/model_loader/weight_utils.py

```diff
@@ -644,9 +644,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name

     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            # Check and remap the name based on modelopt scale names
+            if any(
+                modelopt_scale_name in name
+                for modelopt_scale_name in modelopt_scale_names
+            ):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}",
+                )
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 print_warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
```
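The effect of the two branches is easiest to see on concrete names: modelopt checkpoints attach the scales to k_proj/v_proj, while other FP8 checkpoints attach them directly to the layer. A small standalone sketch of the remapping (re-implemented outside sglang, with the params_dict check and warning omitted):

```python
modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]


def remap_kv_scale_name(name: str) -> str:
    """Standalone sketch of the remapping done in maybe_remap_kv_scale_name."""
    for scale_name in (".k_scale", ".v_scale"):
        if not name.endswith(scale_name):
            continue
        if any(m in name for m in modelopt_scale_names):
            # modelopt style: ".self_attn.k_proj.k_scale" -> ".self_attn.attn.k_scale"
            return name.replace(
                f".self_attn.{scale_name[1]}_proj{scale_name}",
                f".self_attn.attn{scale_name}",
            )
        # generic style: ".k_scale" -> ".attn.k_scale"
        return name.replace(scale_name, f".attn{scale_name}")
    return name


assert (
    remap_kv_scale_name("model.layers.0.self_attn.k_proj.k_scale")
    == "model.layers.0.self_attn.attn.k_scale"
)
assert remap_kv_scale_name("model.layers.0.v_scale") == "model.layers.0.attn.v_scale"
```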
python/sglang/srt/models/llama.py

```diff
@@ -47,6 +47,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
 )
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback

@@ -457,6 +458,11 @@ class LlamaForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
```
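In load_weights, the new branch runs before the stacked-parameter matching: any checkpoint entry whose name contains "scale" is remapped first, and dropped when no runtime parameter matches. A minimal runnable sketch of that control flow, with a hypothetical fake_remap standing in for maybe_remap_kv_scale_name:

```python
from typing import Dict, Iterable, Optional, Tuple


def fake_remap(name: str, params_dict: Dict[str, object]) -> Optional[str]:
    # Hypothetical stand-in for maybe_remap_kv_scale_name.
    remapped = name.replace(".self_attn.k_proj.k_scale", ".self_attn.attn.k_scale")
    return remapped if remapped in params_dict else None


def load_weights_sketch(
    weights: Iterable[Tuple[str, object]], params_dict: Dict[str, object]
) -> None:
    """Sketch of the kv-scale branch added to LlamaForCausalLM.load_weights."""
    for name, tensor in weights:
        if "scale" in name:
            # Remap the checkpoint scale name to its runtime parameter name;
            # None means there is no destination and the tensor is skipped.
            remapped = fake_remap(name, params_dict)
            if remapped is None:
                continue
            name = remapped
        params_dict[name] = tensor  # stand-in for the real weight-loader dispatch


params = {"model.layers.0.self_attn.attn.k_scale": None}
load_weights_sketch([("model.layers.0.self_attn.k_proj.k_scale", 1.0)], params)
assert params["model.layers.0.self_attn.attn.k_scale"] == 1.0
```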
test/srt/test_modelopt_fp8kvcache.py (new file, mode 100644)

```python
import unittest

from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod

from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp8Config,
    ModelOptFp8KVCacheMethod,
)


class TestModelOptFp8KVCacheMethod(unittest.TestCase):
    def test_kv_cache_method_initialization(self):
        """Test that ModelOptFp8KVCacheMethod can be instantiated and
        inherits from BaseKVCacheMethod."""
        # Create a ModelOptFp8Config object
        quant_config = ModelOptFp8Config(is_checkpoint_fp8_serialized=True)

        # Instantiate the KV cache method
        kv_cache_method = ModelOptFp8KVCacheMethod(quant_config)

        # Check inheritance
        self.assertIsInstance(kv_cache_method, BaseKVCacheMethod)

        # Check that the quant_config is stored
        self.assertEqual(kv_cache_method.quant_config, quant_config)


if __name__ == "__main__":
    unittest.main()
```
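Since the module ends with a standard unittest.main() entry point, it can be run directly (python test/srt/test_modelopt_fp8kvcache.py) or through the unittest runner; either way assumes an environment where both sglang and vllm are importable.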