Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
cec98f10
Unverified
Commit
cec98f10
authored
May 09, 2025
by
yhyang201
Committed by
GitHub
May 08, 2025
Browse files
[Fix] Incorrect Memory Allocation on CUDA:0 by Non-Zero CUDA Processes in TP/DP (#5745)
parent
8dc4efd0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
1 deletion
+30
-1
python/sglang/srt/configs/deepseekvl2.py
python/sglang/srt/configs/deepseekvl2.py
+3
-0
python/sglang/srt/configs/janus_pro.py
python/sglang/srt/configs/janus_pro.py
+3
-0
python/sglang/srt/managers/multimodal_processors/base_processor.py
...lang/srt/managers/multimodal_processors/base_processor.py
+5
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+19
-1
No files found.
python/sglang/srt/configs/deepseekvl2.py
View file @
cec98f10
...
...
@@ -48,6 +48,9 @@ class DictOutput(object):
def
__getitem__
(
self
,
item
):
return
self
.
__dict__
[
item
]
def
__contains__
(
self
,
key
):
return
key
in
self
.
__dict__
def
__setitem__
(
self
,
key
,
value
):
self
.
__dict__
[
key
]
=
value
...
...
python/sglang/srt/configs/janus_pro.py
View file @
cec98f10
...
...
@@ -290,6 +290,9 @@ class DictOutput(object):
def
__getitem__
(
self
,
item
):
return
self
.
__dict__
[
item
]
def
__contains__
(
self
,
key
):
return
key
in
self
.
__dict__
def
__setitem__
(
self
,
key
,
value
):
self
.
__dict__
[
key
]
=
value
...
...
python/sglang/srt/managers/multimodal_processors/base_processor.py
View file @
cec98f10
...
...
@@ -8,6 +8,7 @@ from typing import List, Optional
import
numpy
as
np
import
PIL
import
torch
from
PIL
import
Image
from
transformers
import
BaseImageProcessorFast
...
...
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
return_tensors
=
"pt"
,
**
kwargs
,
)
if
"pixel_values"
in
result
and
isinstance
(
result
[
"pixel_values"
],
torch
.
Tensor
):
result
[
"pixel_values"
]
=
result
[
"pixel_values"
].
to
(
"cpu"
)
return
result
@
abstractmethod
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
cec98f10
...
...
@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
out_cache_loc
:
torch
.
Tensor
=
None
# shape: [b], int64
output_ids
:
torch
.
Tensor
=
None
# shape: [b], int64
# For multimodal inputs
multimodal_inputs
:
Optional
[
List
]
=
None
# The sum of all sequence lengths
seq_lens_sum
:
int
=
None
...
...
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Copy prefix and do some basic check
input_embeds
=
[]
extend_input_logprob_token_ids
=
[]
multimodal_inputs
=
[]
for
i
,
(
req
,
seq_len
,
pre_len
)
in
enumerate
(
zip
(
reqs
,
seq_lens
,
prefix_lens
)):
req
.
req_pool_idx
=
req_pool_indices
[
i
]
...
...
@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# If req.input_embeds is already a list, append its content directly
input_embeds
.
extend
(
req
.
input_embeds
)
# Use extend to avoid nesting
multimodal_inputs
.
append
(
req
.
multimodal_inputs
)
req
.
cached_tokens
+=
pre_len
-
req
.
already_computed
req
.
already_computed
=
seq_len
req
.
is_retracted
=
False
...
...
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if
input_embeds
else
None
)
for
mm_input
in
multimodal_inputs
:
if
mm_input
is
None
:
continue
for
mm_item
in
mm_input
.
mm_items
:
pixel_values
=
getattr
(
mm_item
,
"pixel_values"
,
None
)
if
isinstance
(
pixel_values
,
torch
.
Tensor
):
mm_item
.
pixel_values
=
pixel_values
.
to
(
self
.
device
,
non_blocking
=
True
)
self
.
multimodal_inputs
=
multimodal_inputs
self
.
seq_lens_sum
=
sum
(
seq_lens
)
if
self
.
return_logprob
:
...
...
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
encoder_lens_cpu
=
[
self
.
encoder_lens_cpu
[
i
]
for
i
in
keep_indices
]
self
.
reqs
=
[
self
.
reqs
[
i
]
for
i
in
keep_indices
]
self
.
multimodal_inputs
=
[
self
.
multimodal_inputs
[
i
]
for
i
in
keep_indices
]
self
.
req_pool_indices
=
self
.
req_pool_indices
[
keep_indices_device
]
self
.
seq_lens
=
self
.
seq_lens
[
keep_indices_device
]
self
.
out_cache_loc
=
None
...
...
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
top_logprobs_nums
=
[
0
]
*
len
(
self
.
reqs
)
+
other
.
top_logprobs_nums
self
.
token_ids_logprobs
=
[
None
]
*
len
(
self
.
reqs
)
+
other
.
token_ids_logprobs
self
.
reqs
.
extend
(
other
.
reqs
)
self
.
multimodal_inputs
.
extend
(
other
.
multimodal_inputs
)
self
.
return_logprob
|=
other
.
return_logprob
self
.
has_stream
|=
other
.
has_stream
...
...
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_seq_lens
=
extend_seq_lens
,
extend_prefix_lens
=
extend_prefix_lens
,
extend_logprob_start_lens
=
extend_logprob_start_lens
,
multimodal_inputs
=
[
r
.
multimodal_inputs
for
r
in
self
.
reqs
]
,
multimodal_inputs
=
self
.
multimodal_inputs
,
encoder_cached
=
self
.
encoder_cached
,
encoder_lens
=
self
.
encoder_lens
,
encoder_lens_cpu
=
self
.
encoder_lens_cpu
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment