sglang · Commits · b7e2f800

Unverified commit b7e2f800, authored Jun 20, 2024 by Lianmin Zheng, committed by GitHub on Jun 20, 2024

Update flashinfer to 0.0.5 (#554)
parent 09593e9b

Showing 3 changed files with 76 additions and 46 deletions:

    python/sglang/srt/layers/radix_attention.py              +13  -4
    python/sglang/srt/managers/controller/model_runner.py    +62  -41
    python/sglang/srt/server.py                               +1   -1
python/sglang/srt/layers/radix_attention.py    (view file @ b7e2f800)
@@ -12,7 +12,8 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -20,7 +21,6 @@ class RadixAttention(nn.Module):
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
-        self.logit_cap = logit_cap

         assert np.allclose(scaling, 1.0 / (head_dim**0.5))
@@ -30,10 +30,17 @@ class RadixAttention(nn.Module):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer only accepts a boolean logit_cap argument
+            if logit_cap > 0:
+                assert logit_cap == 30
+                self.logit_cap = True
+            else:
+                self.logit_cap = False
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap

     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -100,9 +107,10 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.prefill_wrapper.forward(
+        o = input_metadata.flashinfer_prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -110,9 +118,10 @@ class RadixAttention(nn.Module):
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
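
A minimal standalone sketch (not sglang code) of the logit_cap handling this file now performs: the flashinfer forward() calls above receive logits_cap as a boolean flag, so the integer logit_cap is collapsed to a bool on the flashinfer path, while the Triton kernels keep the raw integer. The helper name resolve_logit_cap is hypothetical.

def resolve_logit_cap(logit_cap: int, use_flashinfer: bool):
    """Mirror of the branching added to RadixAttention.__init__ in this commit."""
    if use_flashinfer:
        # flashinfer only accepts a boolean logit_cap argument (see the comment
        # in the diff); the assert restricts non-default caps to the value 30.
        if logit_cap > 0:
            assert logit_cap == 30
            return True
        return False
    # Triton path: pass the integer through unchanged.
    return logit_cap

if __name__ == "__main__":
    print(resolve_logit_cap(30, use_flashinfer=True))   # True
    print(resolve_logit_cap(-1, use_flashinfer=True))   # False
    print(resolve_logit_cap(30, use_flashinfer=False))  # 30
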
python/sglang/srt/managers/controller/model_runner.py    (view file @ b7e2f800)
@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Any

 import numpy as np
 import torch
@@ -34,7 +34,6 @@ global_server_args_dict = {}
 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -65,15 +64,10 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
-
-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
+    flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

+    def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -93,9 +87,6 @@ class InputMetadata:
             dim=0,
         ).contiguous()

-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
@@ -104,34 +95,30 @@ class InputMetadata:
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)

-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            args = [
+            self.flashinfer_prefill_wrapper.end_forward()
+            self.flashinfer_prefill_wrapper.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
+                1,
             )
-            self.decode_wrapper.begin_forward(
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
                 1,
-                "NONE",
-                "float16",
+                pos_encoding_mode="NONE",
+                data_type="float16",
             )

     def init_extend_args(self):
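
The main behavioral change in this hunk: InputMetadata no longer constructs a fresh flashinfer wrapper (and workspace buffer) for every batch. It reuses the long-lived flashinfer_prefill_wrapper / flashinfer_decode_wrapper objects created once in ModelRunner.init_flash_infer further down, and only re-plans them per batch via end_forward()/begin_forward(). A condensed sketch of that lifecycle, assuming flashinfer 0.0.5 and a CUDA device; only calls that appear in this diff are used, but the function below is illustrative rather than sglang code.

import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
)

# Created once per process (see ModelRunner.init_flash_infer below): a single
# 32 MB workspace buffer shared by both wrappers.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

def plan_decode_batch(kv_indptr, kv_indices, kv_last_page_len,
                      num_qo_heads, num_kv_heads, head_dim):
    # Per batch: tear down the previous plan, then register the new page-table
    # layout, mirroring init_flashinfer_args above.
    decode_wrapper.end_forward()
    decode_wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,  # page size, as passed in the diff
        pos_encoding_mode="NONE",
        data_type="float16",
    )
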
@@ -155,6 +142,8 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -187,7 +176,6 @@ class InputMetadata:
             other_kv_index = None

         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -205,13 +193,19 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )

         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()

         if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.num_key_value_heads // tp_size,
+                model_runner.model_config.head_dim
+            )

         return ret
@@ -234,12 +228,7 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-
-        global global_server_args_dict
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        self.is_multimodal_model = is_multimodal_model(self.model_config)

         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
@@ -269,9 +258,17 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )

+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.init_flash_infer()

     def load_model(self):
         logger.info(
@@ -347,6 +344,22 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+    def init_flash_infer(self):
+        if global_server_args_dict.get("enable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+            workspace_buffer = torch.empty(
+                32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
@@ -360,6 +373,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -378,6 +393,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -398,6 +415,8 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
        )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -416,6 +435,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
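
One consequence of dropping the model_runner field from InputMetadata is visible across the hunks above: the per-rank head counts are now computed by the caller and passed in, instead of InputMetadata reaching back through self.model_runner.model_config and dividing by tp_size itself. A small runnable sketch of that calling convention; _DemoModelConfig and flashinfer_plan_args are illustrative stand-ins, and only the field names come from the diff.

from dataclasses import dataclass

@dataclass
class _DemoModelConfig:
    # Stand-in for sglang's model config; the field names match what the diff reads.
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    head_dim: int = 128

def flashinfer_plan_args(model_config, tp_size: int):
    # Attention heads are split across tensor-parallel ranks, so each rank plans
    # its flashinfer wrappers with per-rank counts, not the global ones.
    return (
        model_config.num_attention_heads // tp_size,
        model_config.num_key_value_heads // tp_size,
        model_config.head_dim,
    )

if __name__ == "__main__":
    print(flashinfer_plan_args(_DemoModelConfig(), tp_size=4))  # (8, 2, 128)
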
python/sglang/srt/server.py    (view file @ b7e2f800)
@@ -150,7 +150,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.disable_disk_cache:
         disable_cache()
     if server_args.enable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.4")
+        assert_pkg_version("flashinfer", "0.0.5")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
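
The version bump itself is enforced by the one-line change above. As a rough illustration of what such a gate does (assert_pkg_version's real implementation lives in sglang's utilities and is not part of this diff), a hedged sketch using importlib.metadata; it requires the packaging package.

from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

def require_at_least(pkg: str, min_version: str) -> None:
    # Fail fast with an actionable message if the dependency is missing or too old.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed; run `pip install '{pkg}>={min_version}'`")
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"{pkg} {installed} is too old; this sglang revision expects >= {min_version}"
        )

# Example: require_at_least("flashinfer", "0.0.5")
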