Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
cec98f10
Unverified
Commit
cec98f10
authored
May 09, 2025
by
yhyang201
Committed by
GitHub
May 08, 2025
Browse files
[Fix] Incorrect Memory Allocation on CUDA:0 by Non-Zero CUDA Processes in TP/DP (#5745)
parent
8dc4efd0
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
1 deletion
+30
-1
python/sglang/srt/configs/deepseekvl2.py
python/sglang/srt/configs/deepseekvl2.py
+3
-0
python/sglang/srt/configs/janus_pro.py
python/sglang/srt/configs/janus_pro.py
+3
-0
python/sglang/srt/managers/multimodal_processors/base_processor.py
...lang/srt/managers/multimodal_processors/base_processor.py
+5
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+19
-1
No files found.
python/sglang/srt/configs/deepseekvl2.py
View file @
cec98f10
...
@@ -48,6 +48,9 @@ class DictOutput(object):
...
@@ -48,6 +48,9 @@ class DictOutput(object):
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
self
.
__dict__
[
item
]
return
self
.
__dict__
[
item
]
def
__contains__
(
self
,
key
):
return
key
in
self
.
__dict__
def
__setitem__
(
self
,
key
,
value
):
def
__setitem__
(
self
,
key
,
value
):
self
.
__dict__
[
key
]
=
value
self
.
__dict__
[
key
]
=
value
...
...
python/sglang/srt/configs/janus_pro.py
View file @
cec98f10
...
@@ -290,6 +290,9 @@ class DictOutput(object):
...
@@ -290,6 +290,9 @@ class DictOutput(object):
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
self
.
__dict__
[
item
]
return
self
.
__dict__
[
item
]
def
__contains__
(
self
,
key
):
return
key
in
self
.
__dict__
def
__setitem__
(
self
,
key
,
value
):
def
__setitem__
(
self
,
key
,
value
):
self
.
__dict__
[
key
]
=
value
self
.
__dict__
[
key
]
=
value
...
...
python/sglang/srt/managers/multimodal_processors/base_processor.py
View file @
cec98f10
...
@@ -8,6 +8,7 @@ from typing import List, Optional
...
@@ -8,6 +8,7 @@ from typing import List, Optional
import
numpy
as
np
import
numpy
as
np
import
PIL
import
PIL
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
BaseImageProcessorFast
from
transformers
import
BaseImageProcessorFast
...
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
...
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
return_tensors
=
"pt"
,
return_tensors
=
"pt"
,
**
kwargs
,
**
kwargs
,
)
)
if
"pixel_values"
in
result
and
isinstance
(
result
[
"pixel_values"
],
torch
.
Tensor
):
result
[
"pixel_values"
]
=
result
[
"pixel_values"
].
to
(
"cpu"
)
return
result
return
result
@
abstractmethod
@
abstractmethod
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
cec98f10
...
@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
out_cache_loc
:
torch
.
Tensor
=
None
# shape: [b], int64
out_cache_loc
:
torch
.
Tensor
=
None
# shape: [b], int64
output_ids
:
torch
.
Tensor
=
None
# shape: [b], int64
output_ids
:
torch
.
Tensor
=
None
# shape: [b], int64
# For multimodal inputs
multimodal_inputs
:
Optional
[
List
]
=
None
# The sum of all sequence lengths
# The sum of all sequence lengths
seq_lens_sum
:
int
=
None
seq_lens_sum
:
int
=
None
...
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Copy prefix and do some basic check
# Copy prefix and do some basic check
input_embeds
=
[]
input_embeds
=
[]
extend_input_logprob_token_ids
=
[]
extend_input_logprob_token_ids
=
[]
multimodal_inputs
=
[]
for
i
,
(
req
,
seq_len
,
pre_len
)
in
enumerate
(
zip
(
reqs
,
seq_lens
,
prefix_lens
)):
for
i
,
(
req
,
seq_len
,
pre_len
)
in
enumerate
(
zip
(
reqs
,
seq_lens
,
prefix_lens
)):
req
.
req_pool_idx
=
req_pool_indices
[
i
]
req
.
req_pool_idx
=
req_pool_indices
[
i
]
...
@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# If req.input_embeds is already a list, append its content directly
# If req.input_embeds is already a list, append its content directly
input_embeds
.
extend
(
req
.
input_embeds
)
# Use extend to avoid nesting
input_embeds
.
extend
(
req
.
input_embeds
)
# Use extend to avoid nesting
multimodal_inputs
.
append
(
req
.
multimodal_inputs
)
req
.
cached_tokens
+=
pre_len
-
req
.
already_computed
req
.
cached_tokens
+=
pre_len
-
req
.
already_computed
req
.
already_computed
=
seq_len
req
.
already_computed
=
seq_len
req
.
is_retracted
=
False
req
.
is_retracted
=
False
...
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if
input_embeds
if
input_embeds
else
None
else
None
)
)
for
mm_input
in
multimodal_inputs
:
if
mm_input
is
None
:
continue
for
mm_item
in
mm_input
.
mm_items
:
pixel_values
=
getattr
(
mm_item
,
"pixel_values"
,
None
)
if
isinstance
(
pixel_values
,
torch
.
Tensor
):
mm_item
.
pixel_values
=
pixel_values
.
to
(
self
.
device
,
non_blocking
=
True
)
self
.
multimodal_inputs
=
multimodal_inputs
self
.
seq_lens_sum
=
sum
(
seq_lens
)
self
.
seq_lens_sum
=
sum
(
seq_lens
)
if
self
.
return_logprob
:
if
self
.
return_logprob
:
...
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
encoder_lens_cpu
=
[
self
.
encoder_lens_cpu
[
i
]
for
i
in
keep_indices
]
self
.
encoder_lens_cpu
=
[
self
.
encoder_lens_cpu
[
i
]
for
i
in
keep_indices
]
self
.
reqs
=
[
self
.
reqs
[
i
]
for
i
in
keep_indices
]
self
.
reqs
=
[
self
.
reqs
[
i
]
for
i
in
keep_indices
]
self
.
multimodal_inputs
=
[
self
.
multimodal_inputs
[
i
]
for
i
in
keep_indices
]
self
.
req_pool_indices
=
self
.
req_pool_indices
[
keep_indices_device
]
self
.
req_pool_indices
=
self
.
req_pool_indices
[
keep_indices_device
]
self
.
seq_lens
=
self
.
seq_lens
[
keep_indices_device
]
self
.
seq_lens
=
self
.
seq_lens
[
keep_indices_device
]
self
.
out_cache_loc
=
None
self
.
out_cache_loc
=
None
...
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
top_logprobs_nums
=
[
0
]
*
len
(
self
.
reqs
)
+
other
.
top_logprobs_nums
self
.
top_logprobs_nums
=
[
0
]
*
len
(
self
.
reqs
)
+
other
.
top_logprobs_nums
self
.
token_ids_logprobs
=
[
None
]
*
len
(
self
.
reqs
)
+
other
.
token_ids_logprobs
self
.
token_ids_logprobs
=
[
None
]
*
len
(
self
.
reqs
)
+
other
.
token_ids_logprobs
self
.
reqs
.
extend
(
other
.
reqs
)
self
.
reqs
.
extend
(
other
.
reqs
)
self
.
multimodal_inputs
.
extend
(
other
.
multimodal_inputs
)
self
.
return_logprob
|=
other
.
return_logprob
self
.
return_logprob
|=
other
.
return_logprob
self
.
has_stream
|=
other
.
has_stream
self
.
has_stream
|=
other
.
has_stream
...
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_seq_lens
=
extend_seq_lens
,
extend_seq_lens
=
extend_seq_lens
,
extend_prefix_lens
=
extend_prefix_lens
,
extend_prefix_lens
=
extend_prefix_lens
,
extend_logprob_start_lens
=
extend_logprob_start_lens
,
extend_logprob_start_lens
=
extend_logprob_start_lens
,
multimodal_inputs
=
[
r
.
multimodal_inputs
for
r
in
self
.
reqs
]
,
multimodal_inputs
=
self
.
multimodal_inputs
,
encoder_cached
=
self
.
encoder_cached
,
encoder_cached
=
self
.
encoder_cached
,
encoder_lens
=
self
.
encoder_lens
,
encoder_lens
=
self
.
encoder_lens
,
encoder_lens_cpu
=
self
.
encoder_lens_cpu
,
encoder_lens_cpu
=
self
.
encoder_lens_cpu
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment