Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
fa909dc3
Unverified
Commit
fa909dc3
authored
Apr 15, 2025
by
Yineng Zhang
Committed by
GitHub
Apr 15, 2025
Browse files
feat: update model_specific_adjustment (#5344)
Co-authored-by:
hebiao064
<
hebiaobuaa@gmail.com
>
parent
e8f62b20
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
17 deletions
+51
-17
python/sglang/srt/layers/attention/flashattention_backend.py
python/sglang/srt/layers/attention/flashattention_backend.py
+1
-1
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+8
-4
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+16
-11
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+26
-1
No files found.
python/sglang/srt/layers/attention/flashattention_backend.py
View file @
fa909dc3
...
@@ -383,7 +383,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -383,7 +383,7 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
]
]
elif
forward_batch
.
forward_mode
.
is_extend_or_draft_extend
():
elif
forward_batch
.
forward_mode
.
is_extend_or_draft_extend
_or_mixed
():
metadata
.
cache_seqlens_int32
=
seqlens_in_batch
.
to
(
torch
.
int32
)
metadata
.
cache_seqlens_int32
=
seqlens_in_batch
.
to
(
torch
.
int32
)
metadata
.
max_seq_len_k
=
forward_batch
.
seq_lens_cpu
.
max
().
item
()
metadata
.
max_seq_len_k
=
forward_batch
.
seq_lens_cpu
.
max
().
item
()
metadata
.
cu_seqlens_k
=
torch
.
nn
.
functional
.
pad
(
metadata
.
cu_seqlens_k
=
torch
.
nn
.
functional
.
pad
(
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
fa909dc3
...
@@ -78,7 +78,7 @@ class ForwardMode(IntEnum):
...
@@ -78,7 +78,7 @@ class ForwardMode(IntEnum):
self
==
ForwardMode
.
EXTEND
self
==
ForwardMode
.
EXTEND
or
self
==
ForwardMode
.
MIXED
or
self
==
ForwardMode
.
MIXED
or
self
==
ForwardMode
.
DRAFT_EXTEND
or
self
==
ForwardMode
.
DRAFT_EXTEND
or
self
==
self
.
TARGET_VERIFY
or
self
==
ForwardMode
.
TARGET_VERIFY
)
)
def
is_decode
(
self
):
def
is_decode
(
self
):
...
@@ -96,6 +96,13 @@ class ForwardMode(IntEnum):
...
@@ -96,6 +96,13 @@ class ForwardMode(IntEnum):
def
is_draft_extend
(
self
):
def
is_draft_extend
(
self
):
return
self
==
ForwardMode
.
DRAFT_EXTEND
return
self
==
ForwardMode
.
DRAFT_EXTEND
def is_extend_or_draft_extend_or_mixed(self):
    """Return True when this mode is EXTEND, DRAFT_EXTEND, or MIXED."""
    extend_like_modes = (
        ForwardMode.EXTEND,
        ForwardMode.DRAFT_EXTEND,
        ForwardMode.MIXED,
    )
    return self in extend_like_modes
def
is_cuda_graph
(
self
):
def
is_cuda_graph
(
self
):
return
(
return
(
self
==
ForwardMode
.
DECODE
self
==
ForwardMode
.
DECODE
...
@@ -103,9 +110,6 @@ class ForwardMode(IntEnum):
...
@@ -103,9 +110,6 @@ class ForwardMode(IntEnum):
or
self
==
ForwardMode
.
IDLE
or
self
==
ForwardMode
.
IDLE
)
)
def is_extend_or_draft_extend(self):
    """Return True when this mode is EXTEND or DRAFT_EXTEND."""
    return self in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND)
def
is_dummy_first
(
self
):
def
is_dummy_first
(
self
):
return
self
==
ForwardMode
.
DUMMY_FIRST
return
self
==
ForwardMode
.
DUMMY_FIRST
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
fa909dc3
...
@@ -78,9 +78,11 @@ from sglang.srt.utils import (
...
@@ -78,9 +78,11 @@ from sglang.srt.utils import (
get_available_gpu_memory
,
get_available_gpu_memory
,
init_custom_process_group
,
init_custom_process_group
,
is_cuda
,
is_cuda
,
is_fa3_default_architecture
,
is_flashinfer_available
,
is_flashinfer_available
,
is_hip
,
is_hip
,
is_hopper_with_cuda_12_3
,
is_hopper_with_cuda_12_3
,
is_no_spec_infer_or_topk_one
,
monkey_patch_p2p_access_check
,
monkey_patch_p2p_access_check
,
monkey_patch_vllm_gguf_config
,
monkey_patch_vllm_gguf_config
,
set_cpu_offload_max_bytes
,
set_cpu_offload_max_bytes
,
...
@@ -242,18 +244,21 @@ class ModelRunner:
...
@@ -242,18 +244,21 @@ class ModelRunner:
elif
server_args
.
attention_backend
is
None
:
elif
server_args
.
attention_backend
is
None
:
# By default, use flashinfer for non-mla attention and triton for mla attention
# By default, use flashinfer for non-mla attention and triton for mla attention
if
not
self
.
use_mla_backend
:
if
not
self
.
use_mla_backend
:
server_args
.
attention_backend
=
(
if
(
"flashinfer"
if
is_flashinfer_available
()
else
"triton"
is_hopper_with_cuda_12_3
()
)
and
is_no_spec_infer_or_topk_one
(
server_args
)
and
is_fa3_default_architecture
(
self
.
model_config
.
hf_config
)
):
server_args
.
attention_backend
=
"fa3"
else
:
server_args
.
attention_backend
=
(
"flashinfer"
if
is_flashinfer_available
()
else
"triton"
)
else
:
else
:
if
is_hopper_with_cuda_12_3
():
if
is_hopper_with_cuda_12_3
()
and
is_no_spec_infer_or_topk_one
(
if
server_args
.
speculative_eagle_topk
is
None
or
(
server_args
server_args
.
speculative_eagle_topk
is
not
None
):
and
server_args
.
speculative_eagle_topk
==
1
server_args
.
attention_backend
=
"fa3"
):
server_args
.
attention_backend
=
"fa3"
else
:
server_args
.
attention_backend
=
"triton"
else
:
else
:
server_args
.
attention_backend
=
"triton"
server_args
.
attention_backend
=
"triton"
logger
.
info
(
logger
.
info
(
...
...
python/sglang/srt/utils.py
View file @
fa909dc3
...
@@ -569,7 +569,7 @@ def encode_video(video_path, frame_count_limit=None):
...
@@ -569,7 +569,7 @@ def encode_video(video_path, frame_count_limit=None):
def
load_image
(
def
load_image
(
image_file
:
Union
[
Image
.
Image
,
str
,
bytes
]
image_file
:
Union
[
Image
.
Image
,
str
,
bytes
]
,
)
->
tuple
[
Image
.
Image
,
tuple
[
int
,
int
]]:
)
->
tuple
[
Image
.
Image
,
tuple
[
int
,
int
]]:
image
=
image_size
=
None
image
=
image_size
=
None
if
isinstance
(
image_file
,
Image
.
Image
):
if
isinstance
(
image_file
,
Image
.
Image
):
...
@@ -1905,3 +1905,28 @@ def get_local_ip_by_remote() -> str:
...
@@ -1905,3 +1905,28 @@ def get_local_ip_by_remote() -> str:
return
s
.
getsockname
()[
0
]
return
s
.
getsockname
()[
0
]
except
Exception
:
except
Exception
:
raise
ValueError
(
f
"Can not get local ip"
)
raise
ValueError
(
f
"Can not get local ip"
)
def is_page_size_one(server_args):
    """Return True when ``server_args.page_size`` equals 1."""
    page_size = server_args.page_size
    return page_size == 1
def is_no_spec_infer_or_topk_one(server_args):
    """Check the speculative top-k setting (and page size) on ``server_args``.

    Returns True when ``server_args.speculative_eagle_topk`` is None.
    Otherwise returns True only when it equals 1 *and*
    ``is_page_size_one(server_args)`` also holds. Note that, despite the
    function name, the top-k == 1 case additionally requires page size 1.
    """
    topk = server_args.speculative_eagle_topk
    if topk is None:
        return True
    # The page-size helper is only consulted once top-k is known to be 1.
    return topk == 1 and is_page_size_one(server_args)
def is_fa3_default_architecture(hf_config):
    """Return True when the first listed architecture is in the FA3 default set.

    Reads ``hf_config.architectures`` (treated as missing when the attribute
    is absent). Anything other than a non-empty ``list`` yields False; only
    the first entry of the list is examined.
    """
    archs = getattr(hf_config, "architectures", None)
    if not (isinstance(archs, list) and archs):
        return False
    primary_arch = archs[0]
    return primary_arch in {
        "Qwen2ForCausalLM",
        "Llama4ForConditionalGeneration",
        "LlamaForCausalLM",
        "MistralForCausalLM",
    }
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment