sglang · Commit d4c038da (Unverified)

[Fix]Fix capture fail bug for DeepSeek (#6275)

Authored May 21, 2025 by Baizhou Zhang, committed by GitHub on May 21, 2025
Parent: 55f6005f
Showing 4 changed files with 20 additions and 13 deletions (+20, -13)
python/sglang/srt/mem_cache/memory_pool.py (+3, -2)
python/sglang/srt/model_executor/cuda_graph_runner.py (+11, -8)
python/sglang/srt/models/deepseek_v2.py (+3, -1)
python/sglang/srt/models/mllama.py (+3, -2)
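Taken together, the four diffs below replace the ad-hoc per-object `capture_mode` attributes (and, in deepseek_v2.py, a `torch.cuda.is_current_stream_capturing()` check) with a single module-level flag in cuda_graph_runner.py: the runner's `model_capture_mode()` context manager toggles it, and call sites query it via `get_is_capture_mode()`. The following is a minimal, self-contained sketch of that pattern, not the real runner; `forward_kv_copy` and its return strings are invented for illustration, and the `try/finally` is an addition the actual diff does not use.

```python
from contextlib import contextmanager

# Module-level flag: any module can ask "is a CUDA graph being captured right now?"
# without the runner having to set a `capture_mode` attribute on each object.
is_capture_mode = False


def get_is_capture_mode():
    return is_capture_mode


@contextmanager
def model_capture_mode():
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        is_capture_mode = False


def forward_kv_copy(alt_stream=None):
    # Call sites (memory_pool, deepseek_v2, mllama) branch on the getter instead of
    # a per-object flag or torch.cuda.is_current_stream_capturing().
    if get_is_capture_mode() and alt_stream is not None:
        return "overlapped copy path (capture)"
    return "default path"


print(forward_kv_copy())                          # default path
with model_capture_mode():
    print(forward_kv_copy(alt_stream=object()))   # overlapped copy path (capture)
```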
python/sglang/srt/mem_cache/memory_pool.py (view file @ d4c038da)

```diff
@@ -266,7 +266,6 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()

         self.layer_transfer_counter = None
-        self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if is_cuda else None

@@ -385,6 +384,8 @@ class MHATokenToKVPool(KVCache):
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
     ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:

@@ -398,7 +399,7 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)

-        if self.capture_mode and self.alt_stream is not None:
+        if get_is_capture_mode() and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
```
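The last hunk only shows the first lines of the overlapped-copy branch; the rest of `set_kv_buffer` is elided by the viewer. As context for the "Overlap the copy of K and V cache" comment, here is a generic sketch of a two-stream copy, with hypothetical buffer names and a hypothetical assignment of K and V to streams, not the pool's actual code:

```python
import torch


def overlapped_kv_copy(k_buffer, v_buffer, loc, cache_k, cache_v, alt_stream):
    # Generic dual-stream copy: one tensor is copied on a side stream while the
    # other is copied on the current stream, then the streams are rejoined.
    device_module = torch.get_device_module("cuda")
    current_stream = device_module.current_stream()
    alt_stream.wait_stream(current_stream)        # side stream sees prior work
    with device_module.stream(alt_stream):
        v_buffer[loc] = cache_v                   # V copy on the side stream
    k_buffer[loc] = cache_k                       # K copy on the current stream
    current_stream.wait_stream(alt_stream)        # later kernels see both copies
```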
python/sglang/srt/model_executor/cuda_graph_runner.py (view file @ d4c038da)

```diff
@@ -47,6 +47,13 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
+
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():

@@ -311,17 +318,12 @@ class CudaGraphRunner:
     @contextmanager
     def model_capture_mode(self):
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = True
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = True
+        global is_capture_mode
+        is_capture_mode = True

         yield

-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = False
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = False
+        is_capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:

@@ -612,6 +614,7 @@ class CudaGraphRunner:
         # Replay
         self.graphs[self.bs].replay()

         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
             return LogitsProcessorOutput(
```
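Compared to the old hasattr-based duck typing, the global flag means a new module does not need to define a `capture_mode` attribute to be told about capture. Below is a simplified, hypothetical sketch of how a capture routine would sit under the context manager so that warmup runs and the actual capture observe the same flag value; `capture_one_batch`, `run_once`, and the two warmup iterations are assumptions, not the runner's real control flow:

```python
import torch


def capture_one_batch(runner, run_once):
    # Hypothetical, simplified capture loop: everything under model_capture_mode(),
    # warmup runs included, sees get_is_capture_mode() == True, so modules take the
    # same branch while warming up and while the graph is actually being recorded.
    with runner.model_capture_mode():
        for _ in range(2):
            run_once()                    # warmup: flag already True here
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            out = run_once()              # captured: same branch as warmup
    return graph, out
```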
python/sglang/srt/models/deepseek_v2.py (view file @ d4c038da)

```diff
@@ -754,6 +754,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1

@@ -761,7 +763,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             k_nope = latent_cache[..., : self.kv_lora_rank]

             # overlap qk norm
-            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
```
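The replaced condition is worth a note: `torch.cuda.is_current_stream_capturing()` is True only while a stream is actively being captured, whereas `get_is_capture_mode()` is True for the whole `model_capture_mode()` region, so eager warmup runs and the recorded graph can take the same code path. A tiny standalone demo of the PyTorch call (requires a CUDA device; the warmup-on-a-side-stream step follows the usual CUDA-graph recipe):

```python
import torch

x = torch.zeros(8, device="cuda")

# Warm up the op on a side stream before capturing.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    x += 1
torch.cuda.current_stream().wait_stream(s)

before = torch.cuda.is_current_stream_capturing()       # False: eager / warmup run
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    during = torch.cuda.is_current_stream_capturing()   # True: active capture
    x += 1

print(before, during)  # False True
g.replay()             # re-runs the captured x += 1
```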
python/sglang/srt/models/mllama.py (view file @ d4c038da)

```diff
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
             prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
-        self.capture_mode = False

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pixel_values = torch.cat(

@@ -969,6 +968,8 @@ class MllamaForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
         )

@@ -977,7 +978,7 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_mask = None
         cross_attention_states = None

-        if self.capture_mode:
+        if get_is_capture_mode():
             # NOTE: when doing cuda graph capture, we do not want to skip cross attention
             # Make is a constant value to avoid cuda graph capture issue
             skip_cross_attention = False
```
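The NOTE in the last hunk is the crux for mllama: a CUDA graph replays a fixed kernel sequence, so a data-dependent decision to skip cross attention would be frozen at capture time and then replayed for batches where it no longer holds. A hedged sketch of that gate with hypothetical names (the real forward derives `skip_cross_attention` from the batch elsewhere):

```python
import torch


def decide_skip_cross_attention(encoder_lens: torch.Tensor) -> bool:
    # Hypothetical helper illustrating the gate in the hunk above.
    from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

    if get_is_capture_mode():
        # While capturing, keep cross attention in the graph unconditionally so the
        # recorded kernel sequence does not depend on this particular batch.
        return False
    # Outside capture the decision may be data-dependent.
    return bool((encoder_lens == 0).all())
```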