change / sglang · Commits · e7487b08

Adjust default mem fraction to avoid OOM (#823)

Unverified commit e7487b08, authored Jul 30, 2024 by Ying Sheng, committed by GitHub on Jul 30, 2024. Parent: ae5c0fc4

Showing 4 changed files with 22 additions and 17 deletions:

python/sglang/srt/layers/radix_attention.py         +1  -1
python/sglang/srt/managers/schedule_batch.py        +8  -8
python/sglang/srt/model_executor/model_runner.py    +8  -3
python/sglang/srt/server_args.py                    +5  -5
python/sglang/srt/layers/radix_attention.py

@@ -103,7 +103,7 @@ class RadixAttention(nn.Module):
         return o

     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        if not input_metadata.use_ragged:
+        if not input_metadata.flashinfer_use_ragged:
             self.store_kv_cache(k, v, input_metadata)

             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
...
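The hunk above is a pure rename: `use_ragged` becomes `flashinfer_use_ragged`, making explicit that the flag belongs to the FlashInfer attention backend rather than to attention in general. A minimal runnable sketch of the branch it guards, with the real kernels reduced to labels (the dataclass and return strings below are illustrative stand-ins, not sglang's API):

from dataclasses import dataclass


@dataclass
class InputMetadataSketch:
    # The renamed flag; defaults to the paged path, as in the diff.
    flashinfer_use_ragged: bool = False


def extend_forward_flashinfer(input_metadata: InputMetadataSketch) -> str:
    if not input_metadata.flashinfer_use_ragged:
        # Paged path: K/V is first written into the paged KV cache,
        # then the paged prefill kernel runs over it.
        return "paged prefill"
    # Ragged path: prefill runs over contiguous K/V instead.
    return "ragged prefill"


print(extend_forward_flashinfer(InputMetadataSketch()))  # paged prefill
print(extend_forward_flashinfer(InputMetadataSketch(flashinfer_use_ragged=True)))  # ragged prefill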
python/sglang/srt/managers/schedule_batch.py

@@ -781,7 +781,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    use_ragged: bool = False
+    flashinfer_use_ragged: bool = False

     @classmethod
     def create(
@@ -797,10 +797,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
-        use_ragged = False
+        flashinfer_use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
             if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                use_ragged = True
+                flashinfer_use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -808,7 +808,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
-                use_ragged,
+                flashinfer_use_ragged,
             )

         batch_size = len(req_pool_indices)
@@ -863,7 +863,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            use_ragged=use_ragged,
+            flashinfer_use_ragged=flashinfer_use_ragged,
         )

         if model_runner.server_args.disable_flashinfer:
@@ -884,7 +884,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
-    use_ragged=False,
+    flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -893,7 +893,7 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))

-    if use_ragged:
+    if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
     else:
         paged_kernel_lens = seq_lens
@@ -929,7 +929,7 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-        if use_ragged:
+        if flashinfer_use_ragged:
             model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
             model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                 qo_indptr,
...
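The renamed flag is driven by a simple heuristic visible in the `create` hunk: ragged prefill is enabled only for non-decode batches whose summed sequence length exceeds 4096 tokens. A self-contained sketch of that decision (the `ForwardMode` enum below is a stand-in; the condition and threshold are copied from the diff):

from enum import Enum, auto

import torch


class ForwardMode(Enum):
    EXTEND = auto()
    DECODE = auto()


def should_use_ragged_prefill(forward_mode: ForwardMode, seq_lens: torch.Tensor) -> bool:
    # Ragged prefill only pays off for large extend/prefill batches;
    # decode batches always use the paged decode kernel.
    return forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096


seq_lens = torch.tensor([2048, 1024, 1536])  # 4608 tokens in total
print(should_use_ragged_prefill(ForwardMode.EXTEND, seq_lens))  # True  (4608 > 4096)
print(should_use_ragged_prefill(ForwardMode.DECODE, seq_lens))  # False (decode is never ragged)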
python/sglang/srt/model_executor/model_runner.py

@@ -212,9 +212,14 @@ class ModelRunner:
         )
         if max_num_reqs is None:
-            max_num_reqs = max(
-                int(self.max_total_num_tokens / self.model_config.context_len * 512),
-                2048,
+            max_num_reqs = min(
+                max(
+                    int(
+                        self.max_total_num_tokens / self.model_config.context_len * 512
+                    ),
+                    2048,
+                ),
+                5120,
             )
         self.req_to_token_pool = ReqToTokenPool(
...
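Previously `max_num_reqs` only had a floor of 2048, so a large KV pool combined with a short context length could inflate the request pool without bound; the new `min(..., 5120)` caps it. A quick worked example of the new formula (the pool size and context length below are assumed values for illustration, not measurements):

def default_max_num_reqs(max_total_num_tokens: int, context_len: int) -> int:
    # Floor of 2048, derived estimate, hard cap of 5120 (new in this commit).
    return min(max(int(max_total_num_tokens / context_len * 512), 2048), 5120)


# int(200_000 / 8192 * 512) = 12500 -> max gives 12500 -> min caps it at 5120
print(default_max_num_reqs(200_000, 8192))  # 5120
# int(10_000 / 8192 * 512) = 625 -> the 2048 floor still applies
print(default_max_num_reqs(10_000, 8192))   # 2048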
python/sglang/srt/server_args.py

@@ -91,15 +91,15 @@ class ServerArgs:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.84
+                self.mem_fraction_static = 0.83
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.86
+                self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
-                self.mem_fraction_static = 0.88
+                self.mem_fraction_static = 0.87
             else:
-                self.mem_fraction_static = 0.89
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
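Every default drops by 0.01, leaving slightly more GPU memory outside the static pool (weights plus KV cache) at each tensor-parallel size, which is the OOM headroom the commit title refers to; an explicit `mem_fraction_static` setting still overrides these defaults. A runnable sketch of the new ladder, mirroring the diff:

def default_mem_fraction_static(tp_size: int) -> float:
    # Post-commit defaults; each tier is 0.01 lower than before.
    if tp_size >= 16:
        return 0.79
    elif tp_size >= 8:
        return 0.83
    elif tp_size >= 4:
        return 0.85
    elif tp_size >= 2:
        return 0.87
    return 0.88


for tp in (1, 2, 4, 8, 16):
    print(f"tp_size={tp}: {default_mem_fraction_static(tp)}")
# tp_size=1: 0.88, tp_size=2: 0.87, tp_size=4: 0.85, tp_size=8: 0.83, tp_size=16: 0.79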