change / sglang · Commits · b7e2f800

Update flashinfer to 0.0.5 (#554)

Unverified commit b7e2f800, authored Jun 20, 2024 by Lianmin Zheng, committed by GitHub on Jun 20, 2024.
Parent: 09593e9b
Showing 3 changed files with 76 additions and 46 deletions (+76 −46):

  python/sglang/srt/layers/radix_attention.py             (+13 −4)
  python/sglang/srt/managers/controller/model_runner.py   (+62 −41)
  python/sglang/srt/server.py                              (+1 −1)
python/sglang/srt/layers/radix_attention.py  (+13 −4)

@@ -12,7 +12,8 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -20,7 +21,6 @@ class RadixAttention(nn.Module):
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
-        self.logit_cap = logit_cap
         assert np.allclose(scaling, 1.0 / (head_dim**0.5))
@@ -30,10 +30,17 @@ class RadixAttention(nn.Module):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer only accepts a boolean logit_cap argument
+            if logit_cap > 0:
+                assert logit_cap == 30
+                self.logit_cap = True
+            else:
+                self.logit_cap = False
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap

     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -100,9 +107,10 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.prefill_wrapper.forward(
+        o = input_metadata.flashinfer_prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -110,9 +118,10 @@ class RadixAttention(nn.Module):
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
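The behavioral core of this file's change is how `logit_cap` is handled per backend: the flashinfer path converts the numeric cap into the boolean `logits_cap` flag that flashinfer 0.0.5 expects, while the triton path keeps the raw value. Below is a minimal standalone sketch of that mapping; it is not part of the commit, the `resolve_logit_cap` helper is hypothetical, and it only restates the branch logic visible in the diff above (including its hard assumption that the only non-zero cap in use is 30).

```python
# Hypothetical helper (not in the commit) restating the logit_cap branch above.
def resolve_logit_cap(logit_cap: int, use_flashinfer: bool):
    if use_flashinfer:
        # flashinfer only accepts a boolean logit_cap argument;
        # the diff asserts the only supported numeric cap is 30.
        if logit_cap > 0:
            assert logit_cap == 30
            return True
        return False
    # The triton kernels keep the numeric cap unchanged.
    return logit_cap


assert resolve_logit_cap(30, use_flashinfer=True) is True
assert resolve_logit_cap(-1, use_flashinfer=True) is False
assert resolve_logit_cap(-1, use_flashinfer=False) == -1
```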
python/sglang/srt/managers/controller/model_runner.py  (+62 −41)

@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Any

 import numpy as np
 import torch
@@ -34,7 +34,6 @@ global_server_args_dict = {}
 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -65,15 +64,10 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
+    flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
-
+    def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -93,9 +87,6 @@ class InputMetadata:
             dim=0,
         ).contiguous()

-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
@@ -104,34 +95,30 @@ class InputMetadata:
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            args = [
+
+            self.flashinfer_prefill_wrapper.end_forward()
+            self.flashinfer_prefill_wrapper.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-
-            self.prefill_wrapper.begin_forward(*args)
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
+                1,
+            )
         else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            self.decode_wrapper.begin_forward(
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
                 1,
-                "NONE",
-                "float16",
+                pos_encoding_mode="NONE",
+                data_type="float16",
             )

     def init_extend_args(self):
@@ -155,6 +142,8 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -187,7 +176,6 @@ class InputMetadata:
         other_kv_index = None

         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -205,13 +193,19 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )

         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()

         if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.num_key_value_heads // tp_size,
+                model_runner.model_config.head_dim,
+            )

         return ret
@@ -234,12 +228,7 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
-        global global_server_args_dict
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }

         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
@@ -269,9 +258,17 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )

+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.init_flash_infer()

     def load_model(self):
         logger.info(
@@ -347,6 +344,22 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+    def init_flash_infer(self):
+        if global_server_args_dict.get("enable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+
+            workspace_buffer = torch.empty(
+                32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
@@ -360,6 +373,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -378,6 +393,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -398,6 +415,8 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -416,6 +435,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
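The main structural change in this file is the flashinfer wrapper lifecycle: instead of `InputMetadata.init_flashinfer_args` importing flashinfer and constructing fresh `BatchPrefillWithPagedKVCacheWrapper` / `BatchDecodeWithPagedKVCacheWrapper` objects (plus a 32 MB workspace buffer) on every batch, `ModelRunner.init_flash_infer` now builds both wrappers once and `InputMetadata` only re-plans them with `end_forward()` / `begin_forward()`. The sketch below illustrates that build-once, re-plan-per-batch pattern in isolation; it assumes a CUDA device, flashinfer 0.0.5, and that the `begin_forward` arguments follow the positional order shown in the diff (the trailing `1` is taken to be the KV-cache page size).

```python
# Sketch of the build-once / re-plan-per-batch pattern introduced by this commit.
# Assumptions: CUDA is available, flashinfer==0.0.5 is installed, and the
# begin_forward signature matches the call sites in the diff above.
import torch
from flashinfer import (
    BatchPrefillWithPagedKVCacheWrapper,
    BatchDecodeWithPagedKVCacheWrapper,
)

# Done once, at model-runner init (see init_flash_infer): one shared workspace
# buffer and one wrapper per phase.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")


def plan_decode_batch(kv_indptr, kv_indices, kv_last_page_len,
                      num_attention_heads, num_key_value_heads, head_dim):
    # Done per batch (see InputMetadata.init_flashinfer_args): drop the previous
    # plan, then register the current batch's paged-KV layout.
    decode_wrapper.end_forward()
    decode_wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_attention_heads,
        num_key_value_heads,
        head_dim,
        1,  # assumed: page size of the token-to-KV pool
        pos_encoding_mode="NONE",
        data_type="float16",
    )
```

Reusing the wrappers avoids re-allocating the workspace buffer and re-importing flashinfer on every forward pass, which is presumably the motivation for moving ownership into `ModelRunner`.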
python/sglang/srt/server.py  (+1 −1)

@@ -150,7 +150,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.disable_disk_cache:
         disable_cache()
     if server_args.enable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.4")
+        assert_pkg_version("flashinfer", "0.0.5")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
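This file simply raises the flashinfer version gate from 0.0.4 to 0.0.5. `assert_pkg_version` itself is not shown in this diff; assuming it verifies that the installed package is at least the requested version, an equivalent standalone check using only the standard library might look like the following (hypothetical helper name, naive X.Y.Z comparison):

```python
# Hypothetical stand-in for the version gate; not sglang's assert_pkg_version.
from importlib.metadata import PackageNotFoundError, version


def check_min_version(pkg: str, min_ver: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed; need >= {min_ver}") from None

    def key(v: str):
        # Keep only leading numeric components, e.g. "0.0.5+cu121" -> (0, 0, 5).
        parts = []
        for p in v.split("+")[0].split("."):
            digits = "".join(ch for ch in p if ch.isdigit())
            parts.append(int(digits) if digits else 0)
        return tuple(parts)

    if key(installed) < key(min_ver):
        raise RuntimeError(f"{pkg} {installed} found, but >= {min_ver} is required")


# Usage (raises if flashinfer is missing or older than 0.0.5):
# check_min_version("flashinfer", "0.0.5")
```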