sglang commit e7487b08 (unverified)
Authored Jul 30, 2024 by Ying Sheng; committed by GitHub on Jul 30, 2024

Adjust default mem fraction to avoid OOM (#823)

Parent: ae5c0fc4

Showing 4 changed files with 22 additions and 17 deletions (+22 / -17):
python/sglang/srt/layers/radix_attention.py       +1 / -1
python/sglang/srt/managers/schedule_batch.py      +8 / -8
python/sglang/srt/model_executor/model_runner.py  +8 / -3
python/sglang/srt/server_args.py                  +5 / -5

python/sglang/srt/layers/radix_attention.py

@@ -103,7 +103,7 @@ class RadixAttention(nn.Module):
         return o

     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        if not input_metadata.use_ragged:
+        if not input_metadata.flashinfer_use_ragged:
             self.store_kv_cache(k, v, input_metadata)

             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
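
The only change in this file is the flag rename; the surrounding method is what dispatches extend (prefill) attention to FlashInfer's paged or ragged wrapper. As a toy illustration of that dispatch, with stub classes standing in for the real FlashInfer wrappers (all names and bodies below are illustrative, not the library API):

# Toy sketch of the prefill dispatch; _StubWrapper is a stand-in, not
# flashinfer's BatchPrefillWithPaged/RaggedKVCacheWrapper.
class _StubWrapper:
    def __init__(self, name: str):
        self.name = name

    def forward(self, q):
        print(f"prefill attention via the {self.name} wrapper")
        return q


class FakeInputMetadata:
    def __init__(self, flashinfer_use_ragged: bool):
        self.flashinfer_use_ragged = flashinfer_use_ragged
        self.flashinfer_prefill_wrapper_paged = _StubWrapper("paged")
        self.flashinfer_prefill_wrapper_ragged = _StubWrapper("ragged")


def extend_forward(q, input_metadata: FakeInputMetadata):
    # Mirrors the renamed condition: the paged wrapper handles prefill
    # unless the batch was flagged for the ragged kernel.
    if not input_metadata.flashinfer_use_ragged:
        return input_metadata.flashinfer_prefill_wrapper_paged.forward(q)
    return input_metadata.flashinfer_prefill_wrapper_ragged.forward(q)


extend_forward("q", FakeInputMetadata(flashinfer_use_ragged=False))  # paged
extend_forward("q", FakeInputMetadata(flashinfer_use_ragged=True))   # ragged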

python/sglang/srt/managers/schedule_batch.py

@@ -781,7 +781,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    use_ragged: bool = False
+    flashinfer_use_ragged: bool = False

     @classmethod
     def create(
@@ -797,10 +797,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
-        use_ragged = False
+        flashinfer_use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
             if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                use_ragged = True
+                flashinfer_use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -808,7 +808,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
-                use_ragged,
+                flashinfer_use_ragged,
             )

         batch_size = len(req_pool_indices)
@@ -863,7 +863,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            use_ragged=use_ragged,
+            flashinfer_use_ragged=flashinfer_use_ragged,
         )

         if model_runner.server_args.disable_flashinfer:
@@ -884,7 +884,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
-    use_ragged=False,
+    flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -893,7 +893,7 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))

-    if use_ragged:
+    if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
     else:
         paged_kernel_lens = seq_lens
@@ -929,7 +929,7 @@ def init_flashinfer_args(
     qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-    if use_ragged:
+    if flashinfer_use_ragged:
         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
         model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
             qo_indptr,
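
The substantive condition in this file decides when the renamed flag is set: ragged prefill is used only for non-decode (extend) batches whose summed sequence length exceeds 4096 tokens. A standalone restatement of that rule, with the hypothetical helper name should_use_ragged_prefill:

import torch


def should_use_ragged_prefill(is_decode: bool, seq_lens: torch.Tensor) -> bool:
    # Mirrors `forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096`:
    # the ragged FlashInfer prefill path only kicks in for large extend batches;
    # decode batches and small prefills stay on the paged path.
    return (not is_decode) and int(torch.sum(seq_lens)) > 4096


# Four 2048-token prefill requests sum to 8192 tokens, crossing the threshold.
print(should_use_ragged_prefill(False, torch.tensor([2048] * 4)))   # True
print(should_use_ragged_prefill(False, torch.tensor([1024, 512])))  # False
print(should_use_ragged_prefill(True, torch.tensor([8192])))        # False (decode)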

python/sglang/srt/model_executor/model_runner.py

@@ -212,9 +212,14 @@ class ModelRunner:
         )
         if max_num_reqs is None:
-            max_num_reqs = max(
-                int(self.max_total_num_tokens / self.model_config.context_len * 512),
-                2048,
+            max_num_reqs = min(
+                max(
+                    int(
+                        self.max_total_num_tokens / self.model_config.context_len * 512
+                    ),
+                    2048,
+                ),
+                5120,
             )
         self.req_to_token_pool = ReqToTokenPool(
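
Before this commit the derived max_num_reqs only had a floor (max(..., 2048)), so a large KV-cache budget combined with a short context length could size the request-to-token pool arbitrarily high; the added min(..., 5120) caps it. A sketch of the new formula with made-up capacities (the helper name and numbers are illustrative):

def default_max_num_reqs(max_total_num_tokens: int, context_len: int) -> int:
    # Scale the request pool with the ratio of KV-cache capacity to context
    # length, then clamp the result into [2048, 5120].
    return min(max(int(max_total_num_tokens / context_len * 512), 2048), 5120)


print(default_max_num_reqs(400_000, 4096))   # 400000/4096*512 = 50000 -> capped at 5120
print(default_max_num_reqs(100_000, 32768))  # ~1562 -> raised to the 2048 floor
print(default_max_num_reqs(200_000, 32768))  # 3125 -> within bounds, kept as-is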

python/sglang/srt/server_args.py

@@ -91,15 +91,15 @@ class ServerArgs:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.84
+                self.mem_fraction_static = 0.83
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.86
+                self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
-                self.mem_fraction_static = 0.88
+                self.mem_fraction_static = 0.87
             else:
-                self.mem_fraction_static = 0.89
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
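
Each default drops by 0.01: mem_fraction_static is the fraction of GPU memory reserved for the static allocation (model weights and the KV-cache pool), so lowering it leaves more headroom for activations and intermediate buffers, which is how this commit avoids OOM. The new ladder, restated as a hypothetical helper:

def default_mem_fraction_static(tp_size: int) -> float:
    # New defaults from this commit; higher tensor-parallel sizes get more
    # conservative fractions because per-GPU overheads grow with tp_size.
    if tp_size >= 16:
        return 0.79
    elif tp_size >= 8:
        return 0.83
    elif tp_size >= 4:
        return 0.85
    elif tp_size >= 2:
        return 0.87
    return 0.88


for tp in (1, 2, 4, 8, 16):
    print(f"tp_size={tp}: mem_fraction_static={default_mem_fraction_static(tp)}")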