sglang · Commits · 2c615d12

[Feature] Support fp8 e5m2 kv cache with flashinfer (#1204)

Unverified commit 2c615d12, authored Aug 26, 2024 by Ke Bao, committed by GitHub Aug 25, 2024
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Parent: 61bb223e

Showing 5 changed files with 116 additions and 16 deletions (+116 -16)
Changed files:
    python/sglang/srt/layers/radix_attention.py              +3   -4
    python/sglang/srt/mem_cache/memory_pool.py                +82  -8
    python/sglang/srt/model_executor/forward_batch_info.py    +4   -0
    python/sglang/srt/model_executor/model_runner.py          +19  -4
    python/sglang/srt/server_args.py                          +8   -0
python/sglang/srt/layers/radix_attention.py

@@ -203,7 +203,6 @@ class RadixAttention(nn.Module):
             return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        k_cache[input_metadata.out_cache_loc] = cache_k
-        v_cache[input_metadata.out_cache_loc] = cache_v
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
+        )
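
The write path now goes through a single pool method instead of indexing the raw K/V buffers, which is what lets the pool hide the fp8 storage format from the attention layer. A minimal sketch of the new call pattern, not part of the diff: the pool construction and the example tensors below are hypothetical, and it assumes a CUDA device plus a PyTorch build with float8 support.

    import torch

    from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool

    # Hypothetical pool: 1024 token slots, fp8 e5m2 storage, 8 KV heads of dim 128, 1 layer.
    pool = MHATokenToKVPool(
        size=1024, dtype=torch.float8_e5m2, head_num=8, head_dim=128, layer_num=1
    )

    loc = torch.arange(4, device="cuda")  # cache slots to write, e.g. out_cache_loc
    cache_k = torch.randn(4, 8, 128, device="cuda", dtype=torch.bfloat16)
    cache_v = torch.randn(4, 8, 128, device="cuda", dtype=torch.bfloat16)

    # The layer no longer touches k_buffer/v_buffer directly; the pool casts the
    # bf16 tensors to fp8 e5m2 and stores them through its uint8 backing buffer.
    pool.set_kv_buffer(0, loc, cache_k, cache_v)
    assert pool.get_key_buffer(0).dtype == torch.float8_e5m2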
python/sglang/srt/mem_cache/memory_pool.py

@@ -16,7 +16,8 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from typing import List, Union
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union
 
 import torch
@@ -52,14 +53,21 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
+        dtype: torch.dtype,
     ):
         self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
 
+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
 
 class MHATokenToKVPool(BaseTokenToKVPool):
@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         # [size, head_num, head_dim] for each layer
         self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
        ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
 
     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v
 
 
 class MLATokenToKVPool(BaseTokenToKVPool):
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         self.kv_lora_rank = kv_lora_rank
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=dtype,
+                dtype=self.store_dtype,
                 device="cuda",
             )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
 
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
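
The NOTE in the pool constructor above is the heart of the change: Tensor index_put is not implemented for torch.float8_e5m2, so the pool allocates its buffers in store_dtype (uint8) and reinterprets them as fp8 only when reading. A standalone sketch of that round trip, not taken from the commit, assuming a PyTorch version with float8_e5m2 support:

    import torch

    buf = torch.zeros(16, 4, dtype=torch.uint8)      # backing buffer in store_dtype
    vals = torch.randn(2, 4).to(torch.float8_e5m2)   # quantized KV entries
    loc = torch.tensor([3, 7])                       # slots to write

    buf[loc] = vals.view(torch.uint8)                # index_put works on uint8
    readback = buf.view(torch.float8_e5m2)[loc]      # zero-copy reinterpret, then gather

    # The fp8 bit patterns survive the uint8 detour unchanged.
    assert torch.equal(readback.view(torch.uint8), vals.view(torch.uint8))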
python/sglang/srt/model_executor/forward_batch_info.py

@@ -315,6 +315,8 @@ def update_flashinfer_indices(
                 num_kv_heads,
                 head_dim,
                 1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
             )
         else:
             # extend part
@@ -393,6 +395,8 @@ def update_flashinfer_indices(
                 num_kv_heads,
                 head_dim,
                 1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
             )
         else:
             # extend part
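
In both decode-mode branches the flashinfer wrapper is now told the KV storage dtype and the query dtype separately: with an fp8 e5m2 cache the queries still arrive in the model dtype, so data_type carries model_runner.kv_cache_dtype while q_data_type carries model_runner.dtype.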
python/sglang/srt/model_executor/model_runner.py

@@ -311,7 +311,7 @@ class ModelRunner:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
@@ -319,7 +319,7 @@ class ModelRunner:
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
@@ -333,6 +333,21 @@ class ModelRunner:
         max_num_reqs: int = None,
         max_total_tokens: int = None,
     ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+                logger.warning(
+                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+                )
+                self.kv_cache_dtype = self.dtype
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
+        else:
+            raise ValueError(
+                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+            )
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
@@ -369,7 +384,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
@@ -380,7 +395,7 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
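
Since profile_max_num_token divides the reserved memory by cell_size, sizing the cell by the element size of kv_cache_dtype is what turns the smaller fp8 footprint into extra token capacity. A back-of-the-envelope sketch with hypothetical Llama-like shapes, using the same internal helper the diff relies on:

    import torch

    head_num, head_dim, layer_num = 32, 128, 32  # hypothetical model shape

    def mha_cell_size(kv_cache_dtype: torch.dtype) -> int:
        # KV cache bytes needed per token, mirroring the MHA branch above.
        return (
            head_num
            * head_dim
            * layer_num
            * 2  # one K and one V entry per layer
            * torch._utils._element_size(kv_cache_dtype)
        )

    print(mha_cell_size(torch.bfloat16))     # 524288 bytes/token
    print(mha_cell_size(torch.float8_e5m2))  # 262144 bytes/token -> roughly 2x more cached tokens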
python/sglang/srt/server_args.py

@@ -33,6 +33,7 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
@@ -196,6 +197,13 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--kv-cache-dtype",
+            type=str,
+            default=ServerArgs.kv_cache_dtype,
+            choices=["auto", "fp8_e5m2"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
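
The flag defaults to "auto" (keep the model dtype) and only takes effect on the flashinfer path, per the model_runner.py change above. A hypothetical launch command, assuming the usual sglang.launch_server entry point and a placeholder model path:

    python -m sglang.launch_server --model-path <model> --kv-cache-dtype fp8_e5m2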