change / sglang · Commits

Commit 78b7465c (unverified), authored Sep 13, 2025 by kk, committed via GitHub Sep 12, 2025

Fix GPU fault issue when run dsv3 with dp mode and enable torch-compile (#10361)

Co-authored-by: wunhuang <wunhuang@amd.com>

Parent: 07bcad7f
Showing 2 changed files with 39 additions and 5 deletions (+39 / -5):

- python/sglang/srt/layers/dp_attention.py (+24 / -0)
- python/sglang/srt/layers/logits_processor.py (+15 / -5)
python/sglang/srt/layers/dp_attention.py (view file @ 78b7465c)
```diff
@@ -119,6 +119,18 @@ class _DpGatheredBufferWrapper:
     def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens
+
+    @classmethod
+    def get_dp_hidden_size(cls) -> int:
+        return cls._hidden_size
+
+    @classmethod
+    def get_dp_dtype(cls) -> torch.dtype:
+        return cls._dtype
+
+    @classmethod
+    def get_dp_device(cls) -> torch.device:
+        return cls._device
 
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
```
```diff
@@ -150,6 +162,18 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
 
 
+def get_dp_hidden_size() -> int:
+    return _DpGatheredBufferWrapper.get_dp_hidden_size()
+
+
+def get_dp_dtype() -> torch.dtype:
+    return _DpGatheredBufferWrapper.get_dp_dtype()
+
+
+def get_dp_device() -> torch.device:
+    return _DpGatheredBufferWrapper.get_dp_device()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
```
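Taken together, the new class methods and their module-level wrappers expose the cached DP buffer geometry (hidden size, dtype, device) outside `_DpGatheredBufferWrapper`. As a rough usage sketch, the helper below is hypothetical and not part of this commit; it simply mirrors how `logits_processor.py` uses the getters later in this diff, and it assumes the DP attention buffers have already been initialized by the runtime:

```python
import torch

from sglang.srt.layers.dp_attention import (
    get_dp_device,
    get_dp_dtype,
    get_dp_hidden_size,
)


def allocate_gather_buffer(num_tokens: int) -> torch.Tensor:
    """Hypothetical helper: allocate a DP gather buffer sized from the
    cached hidden size, dtype, and device, instead of reaching into
    _DpGatheredBufferWrapper directly."""
    return torch.empty(
        (num_tokens, get_dp_hidden_size()),
        dtype=get_dp_dtype(),
        device=get_dp_device(),
    )
```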
python/sglang/srt/layers/logits_processor.py (view file @ 78b7465c)
```diff
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_dp_device,
+    get_dp_dtype,
+    get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
```
```diff
@@ -187,16 +190,23 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        hidden_size = get_dp_hidden_size()
+        dtype = get_dp_dtype()
+        device = get_dp_device()
+
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
             self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
         set_dp_buffer_len(
             self.global_dp_buffer_len,
             self.dp_local_num_tokens,
             self.global_num_tokens_for_logprob_cpu,
         )
 
+        self.gathered_buffer = torch.empty(
+            (
+                self.global_dp_buffer_len,
+                hidden_size,
+            ),
+            dtype=dtype,
+            device=device,
+        )
```
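A small, self-contained illustration of the buffer sizing above; the token counts and hidden size are made-up values, whereas in the real path they come from `global_num_tokens_for_logprob_cpu` and `get_dp_hidden_size()`:

```python
import torch

# Pretend four DP ranks report these token counts for logprob computation.
global_num_tokens_for_logprob_cpu = [3, 1, 5, 2]
global_dp_buffer_len = sum(global_num_tokens_for_logprob_cpu)  # 11

# The gathered buffer holds one row per token across all ranks.
hidden_size = 4096  # made-up; the real value comes from get_dp_hidden_size()
gathered_buffer = torch.empty(
    (global_dp_buffer_len, hidden_size), dtype=torch.bfloat16
)
print(gathered_buffer.shape)  # torch.Size([11, 4096])
```

Sizing from the logprob token counts (when available) keeps this per-metadata buffer smaller than the full DP gather buffer, which matches the existing "reduce peak memory usage" comment.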
```diff
@@ -443,7 +453,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                get_global_dp_buffer(),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
```
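The gather now targets the metadata-owned `gathered_buffer` rather than the process-global buffer returned by `get_global_dp_buffer()`; per the commit title, this swap is the fix for the GPU fault seen when running dsv3 with DP mode and torch-compile enabled. Not part of the commit, but as a single-process stand-in for what the gather accomplishes (the real `dp_gather_replicate` is a distributed operation and is not reproduced here), each rank's local hidden states land at that rank's offset in the buffer:

```python
import torch

# Made-up per-rank token counts and hidden size.
global_num_tokens = [2, 4, 3]
hidden_size = 8
global_dp_buffer_len = sum(global_num_tokens)

# Each "rank" holds only its own local hidden states.
local_hidden = [torch.randn(n, hidden_size) for n in global_num_tokens]

# One row per token across all ranks, like the torch.empty allocation
# in compute_dp_attention_metadata().
gathered = torch.empty(global_dp_buffer_len, hidden_size)

offset = 0
for n, local in zip(global_num_tokens, local_hidden):
    # dp_local_start_pos / dp_local_num_tokens play this role per rank.
    gathered[offset : offset + n] = local
    offset += n

assert gathered.shape == (global_dp_buffer_len, hidden_size)
```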