sglang / Commits / d3024f4f

Commit d3024f4f (unverified), authored Jan 18, 2025 by bjmsong; committed via GitHub on Jan 18, 2025.

support e4m3 kvcache in qwen2 & add kv scaling factor json (#2894)

Co-authored-by: bjmsong <bjmsong@126.com>
Parent: 13387e6b

Showing 8 changed files with 227 additions and 9 deletions (+227, -9).
Changed files:

- python/sglang/srt/model_loader/weight_utils.py (+54, -1)
- python/sglang/srt/models/llama.py (+4, -2)
- python/sglang/srt/models/qwen2.py (+34, -2)
- python/sglang/test/test_utils.py (+1, -0)
- test/srt/kv_cache_scales_llama3_8b.json (+42, -0)
- test/srt/kv_cache_scales_qwen2_1_5b.json (+38, -0)
- test/srt/run_suite.py (+1, -0)
- test/srt/test_fp8_kvcache.py (+53, -4)
python/sglang/srt/model_loader/weight_utils.py

```diff
@@ -9,7 +9,17 @@ import logging
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)

 import filelock
 import gguf
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
     # If there were no matches, return the untouched param name
     return name
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.exception("An error occurred while reading '%s'.", filename)
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
```
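For reference, the loader's contract can be exercised directly against one of the fixture files added further down in this commit. A minimal sketch (standalone invocation is an assumption; inside sglang the call happens during model loading):

```python
# Sketch: exercising kv_cache_scales_loader against the qwen2 fixture from
# this commit. Each TP rank receives (layer_index, scaling_factor) pairs;
# on any error the loader logs and returns [] (=> scales default to 1.0).
from sglang.srt.model_loader.weight_utils import kv_cache_scales_loader

for layer_idx, scale in kv_cache_scales_loader(
    "test/srt/kv_cache_scales_qwen2_1_5b.json",  # fixture added below
    tp_rank=0,
    tp_size=1,
    num_hidden_layers=28,  # Qwen2.5-1.5B has 28 layers, matching the fixture
    model_type="qwen",
):
    print(layer_idx, scale)
```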
python/sglang/srt/models/llama.py

```diff
@@ -23,7 +23,6 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -45,7 +44,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
```
python/sglang/srt/models/qwen2.py

```diff
@@ -22,7 +22,10 @@ import torch
 from torch import nn
 from vllm.model_executor.layers.rotary_embedding import get_rope

-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -39,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import make_layers

 Qwen2Config = None
@@ -265,6 +271,29 @@ class Qwen2Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+                if hasattr(layer_self_attn.attn, "k_scale"):
+                    layer_self_attn.attn.k_scale = scaling_factor
+                    layer_self_attn.attn.v_scale = scaling_factor
+                else:
+                    raise RuntimeError(
+                        "Self attention has no KV cache scaling "
+                        "factor attribute!"
+                    )
+

 class Qwen2ForCausalLM(nn.Module):
@@ -373,5 +402,8 @@ class Qwen2ForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()

+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+

 EntryClass = Qwen2ForCausalLM
```
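The scales written onto `attn.k_scale` / `attn.v_scale` are what make an e4m3 KV cache usable: float8_e4m3fn has a small finite range (maximum 448), so values are divided by the scale before storage and multiplied by it on read. A minimal sketch of that roundtrip, illustrative only — the actual (de)quantization happens inside the attention kernels, not in Python:

```python
import torch

# Illustrative fp8 e4m3 roundtrip with a per-layer scale (assumption: this
# mirrors the convention where dequantization multiplies by the scale).
scale = 0.0645  # e.g. layer 1 of the qwen2 fixture below
k = torch.randn(4, 8) * 0.05                  # stand-in for key states
k_fp8 = (k / scale).to(torch.float8_e4m3fn)   # what the KV cache stores
k_restored = k_fp8.to(torch.float32) * scale  # recovered at attention time
print((k - k_restored).abs().max())           # error bounded by e4m3 precision
```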
python/sglang/test/test_utils.py

```diff
@@ -40,6 +40,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mis
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"


 def is_in_ci():
```
test/srt/kv_cache_scales_llama3_8b.json
0 → 100644
View file @
d3024f4f
{
"model_type"
:
"llama"
,
"kv_cache"
:
{
"dtype"
:
"float8_e4m3fn"
,
"scaling_factor"
:
{
"0"
:
{
"0"
:
0.0408
,
"1"
:
0.0503
,
"2"
:
0.0667
,
"3"
:
0.0909
,
"4"
:
0.1135
,
"5"
:
0.127
,
"6"
:
0.1768
,
"7"
:
0.1488
,
"8"
:
0.1135
,
"9"
:
0.1203
,
"10"
:
0.1013
,
"11"
:
0.0842
,
"12"
:
0.1231
,
"13"
:
0.1096
,
"14"
:
0.1221
,
"15"
:
0.1013
,
"16"
:
0.1067
,
"17"
:
0.0952
,
"18"
:
0.0899
,
"19"
:
0.097
,
"20"
:
0.087
,
"21"
:
0.0994
,
"22"
:
0.0904
,
"23"
:
0.1013
,
"24"
:
0.1019
,
"25"
:
0.1053
,
"26"
:
0.1
,
"27"
:
0.0894
,
"28"
:
0.1013
,
"29"
:
0.1488
,
"30"
:
0.0766
,
"31"
:
0.0821
}
}
}
}
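The commit ships precomputed factors but not the calibration script that produced them. A common recipe (an assumption here, not taken from this commit) records the absolute maximum of each layer's K/V activations over calibration data and divides by e4m3's maximum finite value:

```python
import torch

# Hypothetical calibration sketch: choose a per-layer scale so the observed
# KV activation range maps onto float8_e4m3fn's finite range.
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def kv_scale_from_amax(kv_activations: torch.Tensor) -> float:
    """Scale such that the observed amax quantizes to the largest e4m3 value."""
    return (kv_activations.abs().max() / FP8_E4M3_MAX).item()
```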
test/srt/kv_cache_scales_qwen2_1_5b.json (new file, mode 100644)

```json
{
    "model_type": "qwen",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            "0": {
                "0": 0.9846,
                "1": 0.0645,
                "2": 0.0731,
                "3": 0.0800,
                "4": 0.0748,
                "5": 0.0780,
                "6": 0.0702,
                "7": 0.0894,
                "8": 0.0410,
                "9": 0.0758,
                "10": 0.0556,
                "11": 0.0731,
                "12": 0.0899,
                "13": 0.0780,
                "14": 0.1441,
                "15": 0.0914,
                "16": 0.5614,
                "17": 0.1067,
                "18": 0.0537,
                "19": 0.0658,
                "20": 0.0523,
                "21": 0.0533,
                "22": 0.0699,
                "23": 0.0635,
                "24": 0.0588,
                "25": 0.0884,
                "26": 0.0947,
                "27": 0.1032
            }
        }
    }
}
```
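Both fixtures follow the same TP-rank → layer → scale layout that `kv_cache_scales_loader` consumes (the loader itself validates it with `QuantParamSchema`). A plain-JSON structural sanity check, as a sketch:

```python
import json

# Sketch: verify a scales fixture has the expected layout using plain json
# (the real loader validates via QuantParamSchema instead).
with open("test/srt/kv_cache_scales_qwen2_1_5b.json") as f:
    doc = json.load(f)

assert doc["kv_cache"]["dtype"] == "float8_e4m3fn"
rank0 = doc["kv_cache"]["scaling_factor"]["0"]           # scales for TP rank 0
assert sorted(int(i) for i in rank0) == list(range(28))  # one scale per layer
```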
test/srt/run_suite.py

```diff
@@ -52,6 +52,7 @@ suites = {
         "test_vision_openai_server.py",
         "test_w8a8_quantization.py",
         "test_session_control.py",
+        "test_fp8_kvcache.py",
     ],
     "nightly": [
         "test_nightly_gsm8k_eval.py",
```
test/srt/test_fp8_kvcache.py

```diff
@@ -6,19 +6,26 @@ from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )


-class TestFp8Kvcache(unittest.TestCase):
+class TestFp8KvcacheBase(unittest.TestCase):
+    model_config = None
+
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        if cls.model_config is None:
+            raise NotImplementedError("model_config must be specified in subclass")
+
+        cls.model = cls.model_config["model_name"]
         cls.base_url = DEFAULT_URL_FOR_TEST
         dirpath = os.path.dirname(__file__)
-        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
+        config_file = os.path.join(dirpath, cls.model_config["config_filename"])
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
@@ -31,6 +38,13 @@ class TestFp8Kvcache(unittest.TestCase):
             ],
         )

+
+class TestFp8KvcacheLlama(TestFp8KvcacheBase):
+    model_config = {
+        "model_name": DEFAULT_MODEL_NAME_FOR_TEST,
+        "config_filename": "kv_cache_scales_llama3_8b.json",
+    }
+
     @classmethod
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)
@@ -45,7 +59,7 @@ class TestFp8Kvcache(unittest.TestCase):
         )
         metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.835)
+        self.assertGreater(metrics["score"], 0.80)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -60,5 +74,40 @@ class TestFp8Kvcache(unittest.TestCase):
         self.assertGreaterEqual(metrics["score"], 0.65)


+class TestFp8KvcacheQwen(TestFp8KvcacheBase):
+    model_config = {
+        "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
+        "config_filename": "kv_cache_scales_qwen2_1_5b.json",
+    }
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+        metrics = run_eval(args)
+        self.assertGreater(metrics["score"], 0.01)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.3)
+

 if __name__ == "__main__":
     unittest.main()
```
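With the test refactored around `TestFp8KvcacheBase`, covering a further model only requires a new subclass plus a matching fixture file. A sketch — the model name and fixture filename below are hypothetical placeholders, not part of this commit:

```python
# Hypothetical extension of the refactored test harness.
class TestFp8KvcacheAnotherModel(TestFp8KvcacheBase):
    model_config = {
        "model_name": "org/another-model",                       # placeholder
        "config_filename": "kv_cache_scales_another_model.json", # placeholder
    }
```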