Commit 08c4d764 (unverified)
Authored Mar 08, 2025 by Lianmin Zheng · Committed by GitHub, Mar 08, 2025

lazy import attn backends (#4200)

Parent: 96d0e37f

Changes: 4 changed files with 21 additions and 11 deletions (+21 -11)
python/sglang/srt/layers/attention/triton_backend.py    +1  -3
python/sglang/srt/model_executor/cuda_graph_runner.py   +1  -1
python/sglang/srt/model_executor/model_runner.py        +18 -6
test/srt/test_eagle_infer.py                            +1  -1
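The core of the commit is in model_runner.py: the attention-backend imports move out of module scope and into init_attention_backend, so each backend module is imported only in the branch that actually selects it. A minimal sketch of that pattern, using standard-library modules as stand-ins for the heavy backend dependencies (illustrative only, not the sglang code):

def load_serializer(kind: str):
    # Lazy imports: each branch imports its dependency only when that branch
    # runs, so merely importing this file pays for neither module.
    if kind == "json":
        import json  # stand-in for a heavy optional backend

        return json.dumps
    elif kind == "pickle":
        import pickle

        return pickle.dumps
    else:
        raise ValueError(f"Invalid serializer: {kind}")


print(load_serializer("json")({"ok": True}))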
python/sglang/srt/layers/attention/triton_backend.py

@@ -6,9 +6,7 @@ import torch
 import triton
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -302,7 +302,7 @@ class CudaGraphRunner:
             self.stream = graph_capture_context.stream
             # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
-                tqdm.tqdm(reversed(self.capture_bs))
+                tqdm.tqdm(list(reversed(self.capture_bs)))
                 if get_tensor_model_parallel_rank() == 0
                 else reversed(self.capture_bs)
             )
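The only change in this file wraps reversed(self.capture_bs) in list() before handing it to tqdm. A likely reason (my reading, not stated in the commit): reversed() returns an iterator without __len__, so tqdm cannot determine a total and shows a count-only bar; materializing the list restores len() and gives a proper progress display. A quick standard-library check of the difference:

capture_bs = [1, 2, 4, 8, 16]

rev_iter = reversed(capture_bs)        # list_reverseiterator: no __len__
try:
    len(rev_iter)
except TypeError:
    print("reversed() result has no len(); tqdm would show no total")

rev_list = list(reversed(capture_bs))  # plain list: len() works again
assert len(rev_list) == 5              # tqdm can render "k/5"-style progress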
python/sglang/srt/model_executor/model_runner.py

@@ -35,11 +35,6 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
@@ -77,7 +72,6 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -779,6 +773,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
@@ -794,12 +792,26 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
             self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
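Moving the imports into init_attention_backend raises the question of repeated cost, but in CPython an import statement inside a function is effectively a sys.modules cache lookup after the first execution, so only the first selection of a backend pays the real import price. A small illustration, with a standard-library module standing in for a backend module:

import sys


def init_backend():
    # The module is fully imported on the first call (unless something else
    # already loaded it); later calls find it cached in sys.modules and just
    # rebind the local name.
    import json  # stand-in for e.g. a flashinfer-style backend module

    return json


init_backend()                  # first call may trigger the real import
assert "json" in sys.modules    # cached from now on
init_backend()                  # subsequent calls are a cheap lookup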
test/srt/test_eagle_infer.py

@@ -108,7 +108,7 @@ class TestEAGLEEngine(unittest.TestCase):
     def _test_eos_token(self, engine):
         prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
         params = {
-            "temperature": 0,
+            "temperature": 0.1,
             "max_new_tokens": 1024,
             "skip_special_tokens": False,
         }