change / sglang / Commits / 0bb0f763

Commit 0bb0f763, authored Jan 13, 2025 by bjmsong, committed by GitHub Jan 12, 2025
Parent: 85b2e057

Support FP8 E4M3 KV Cache (#2786)

Co-authored-by: root <bjmsong@126.com>
Showing 9 changed files with 205 additions and 10 deletions (+205 -10)

    python/sglang/srt/layers/attention/flashinfer_backend.py   +13  -3
    python/sglang/srt/layers/radix_attention.py                  +2  -0
    python/sglang/srt/mem_cache/memory_pool.py                   +6  -4
    python/sglang/srt/model_executor/model_runner.py            +27  -0
    python/sglang/srt/models/llama.py                           +32  -1
    python/sglang/srt/server_args.py                            +13  -2
    python/sglang/srt/utils.py                                   +6  -0
    test/srt/kv_cache_scales_llama3_1_8b.json                   +42  -0
    test/srt/test_fp8_kvcache.py                                +64  -0
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -353,7 +353,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +364,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -412,13 +418,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
python/sglang/srt/layers/radix_attention.py

@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = 1.0
+        self.v_scale = 1.0

     def forward(
         self,
python/sglang/srt/mem_cache/memory_pool.py

@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
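The second hunk above is the core of the feature: when the incoming K/V dtype differs from the cache dtype, each tensor is divided by its per-layer scale before the FP8 cast, and the bytes are written through a uint8 view because index_put is not implemented for FP8 tensors. Below is a minimal, self-contained sketch of that store/read round trip; the buffer shapes, the k_scale value, and the variable names are illustrative and not taken from the commit.

    import torch

    dtype = torch.float8_e4m3fn      # logical KV cache dtype (requires PyTorch >= 2.1)
    store_dtype = torch.uint8        # physical storage dtype; index_put works on uint8
    k_buffer = torch.zeros(8, 4, 16, dtype=store_dtype)   # 8 slots, 4 heads, head_dim 16

    cache_k = torch.randn(2, 4, 16, dtype=torch.float16)  # new K entries to cache
    loc = torch.tensor([3, 5])                            # slots to write
    k_scale = 2.0                                         # per-layer scaling factor

    # Quantize on write: divide by the scale, cast to FP8, store the raw bytes via a uint8 view.
    k_buffer[loc] = (cache_k / k_scale).to(dtype).view(store_dtype)

    # Read back: reinterpret the bytes as FP8, upcast, and undo the scaling.
    k_restored = k_buffer[loc].view(dtype).to(torch.float16) * k_scale
    print((k_restored - cache_k).abs().max())  # small FP8 quantization error

On the real serving path the dequantization is not done in Python: the k_scale and v_scale values are handed to FlashInfer's forward calls instead, as shown in the flashinfer_backend.py hunks above.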
python/sglang/srt/model_executor/model_runner.py

@@ -54,6 +54,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
@@ -277,6 +278,29 @@ class ModelRunner:
             device_config=DeviceConfig(self.device),
         )
+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors.",
+                        self.model.__class__,
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )

         # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
@@ -516,6 +540,9 @@
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
python/sglang/srt/models/llama.py

@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling "
+                    "factor attribute!"
+                )

 class LlamaForCausalLM(nn.Module):
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()

+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
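kv_cache_scales_loader comes from vLLM's weight utilities; for the current tensor-parallel rank it yields (layer_idx, scaling_factor) pairs read from the JSON file passed as quantization_param_path. As a rough illustration of that contract only, here is a simplified stand-in, not vLLM's implementation; it assumes the JSON layout of the test file added below, where the outer key under "scaling_factor" appears to index the TP rank.

    import json
    from typing import Iterator, Tuple

    def toy_kv_cache_scales_loader(
        path: str, tp_rank: int, tp_size: int, num_hidden_layers: int, model_type: str
    ) -> Iterator[Tuple[int, float]]:
        # Simplified reader for files shaped like test/srt/kv_cache_scales_llama3_1_8b.json.
        # tp_size is unused here; the real loader also validates and shards across ranks.
        with open(path) as f:
            data = json.load(f)
        assert data["model_type"] == model_type
        rank_scales = data["kv_cache"]["scaling_factor"][str(tp_rank)]
        for layer_idx in range(num_hidden_layers):
            # Missing layers fall back to the dummy factor of 1.0.
            yield layer_idx, float(rank_scales.get(str(layer_idx), 1.0))

LlamaModel.load_kv_cache_scales then copies each factor onto the k_scale and v_scale attributes of the RadixAttention instance, which were introduced in radix_attention.py above with a default of 1.0.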
python/sglang/srt/server_args.py

@@ -32,6 +32,7 @@ from sglang.srt.utils import (
     is_hip,
     is_ipv6,
     is_port_available,
+    nullable_str,
 )

 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: nullable_str = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -350,8 +352,17 @@
         "--kv-cache-dtype",
         type=str,
         default=ServerArgs.kv_cache_dtype,
-        choices=["auto", "fp8_e5m2"],
-        help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+        help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
     )
+    parser.add_argument(
+        "--quantization-param-path",
+        type=nullable_str,
+        default=None,
+        help="Path to the JSON file containing the KV cache "
+        "scaling factors. This should generally be supplied, when "
+        "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+        "default to 1.0, which may cause accuracy issues. ",
+    )
     parser.add_argument(
         "--quantization",
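A quick way to check that the new flags are wired through to the dataclass is to build a parser and parse a sample command line. This is only a sketch: it assumes the ServerArgs.add_cli_args / ServerArgs.from_cli_args helpers that wrap these parser.add_argument calls, and the model path is purely illustrative.

    import argparse
    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args(
        [
            "--model-path", "meta-llama/Llama-3.1-8B-Instruct",   # illustrative model
            "--kv-cache-dtype", "fp8_e4m3",
            "--quantization-param-path", "test/srt/kv_cache_scales_llama3_1_8b.json",
        ]
    )
    server_args = ServerArgs.from_cli_args(args)
    print(server_args.kv_cache_dtype, server_args.quantization_param_path)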
python/sglang/srt/utils.py

@@ -1375,3 +1375,9 @@ def debug_timing(func):
             return func(*args, **kwargs)

     return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
test/srt/kv_cache_scales_llama3_1_8b.json (new file, mode 100644)

{
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            "0": {
                "0": 1, "1": 1, "2": 1, "3": 1, "4": 1, "5": 1, "6": 1, "7": 1,
                "8": 1, "9": 1, "10": 1, "11": 1, "12": 1, "13": 1, "14": 1, "15": 1,
                "16": 1, "17": 1, "18": 1, "19": 1, "20": 1, "21": 1, "22": 1, "23": 1,
                "24": 1, "25": 1, "26": 1, "27": 1, "28": 1, "29": 1, "30": 1, "31": 1
            }
        }
    }
}
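Since this file simply assigns a factor of 1 to every one of the 32 hidden layers of an 8B Llama on TP rank 0, an equivalent file can be generated with a few lines of Python. A sketch, assuming the same all-ones dummy scales and a single tensor-parallel rank:

    import json

    num_layers = 32  # Llama 3.1 8B has 32 hidden layers
    scales = {
        "model_type": "llama",
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            # Outer key "0" is the tensor-parallel rank; inner keys are layer indices.
            "scaling_factor": {"0": {str(i): 1 for i in range(num_layers)}},
        },
    }

    with open("kv_cache_scales_llama3_1_8b.json", "w") as f:
        json.dump(scales, f, indent=4)

A real deployment would replace the 1s with calibrated per-layer scaling factors.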
test/srt/test_fp8_kvcache.py (new file, mode 100644)

import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestFp8Kvcache(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        dirpath = os.path.dirname(__file__)
        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-cache-dtype",
                "fp8_e4m3",
                "--quantization-param-path",
                config_file,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.835)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)


if __name__ == "__main__":
    unittest.main()