sglang · Commits · 01dd39ba

Unverified commit 01dd39ba, authored May 18, 2025 by Mick, committed via GitHub May 17, 2025
Parent: b3f3d610

refactor: minor refactors regarding multimodal processing (#6187)

Showing 15 changed files with 140 additions and 98 deletions (+140 −98)
python/sglang/srt/configs/model_config.py                           +13 −28
python/sglang/srt/hf_transformers_utils.py                          +46 −6
python/sglang/srt/layers/attention/vision.py                        +1 −1
python/sglang/srt/managers/io_struct.py                             +9 −1
python/sglang/srt/managers/mm_utils.py                              +19 −4
python/sglang/srt/managers/multimodal_processors/base_processor.py  +19 −15
python/sglang/srt/managers/multimodal_processors/minicpm.py         +0 −28
python/sglang/srt/managers/schedule_batch.py                        +4 −2
python/sglang/srt/managers/tokenizer_manager.py                     +10 −8
python/sglang/srt/mm_utils.py                                       +10 −0
python/sglang/srt/model_executor/model_runner.py                    +1 −1
python/sglang/srt/models/minicpmo.py                                +5 −2
python/sglang/srt/models/mllama.py                                  +0 −1
python/sglang/srt/models/qwen2_5_vl.py                              +2 −0
test/srt/test_vision_chunked_prefill.py                             +1 −1
python/sglang/srt/configs/model_config.py

@@ -22,7 +22,11 @@ from typing import List, Optional, Set, Union
 import torch
 from transformers import PretrainedConfig

-from sglang.srt.hf_transformers_utils import get_config, get_context_length
+from sglang.srt.hf_transformers_utils import (
+    get_config,
+    get_context_length,
+    get_hf_text_config,
+)
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip

@@ -209,7 +213,13 @@ class ModelConfig:
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
-        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
+
+        config = self.hf_config
+
+        # multimodal
+        self.image_token_id = getattr(config, "image_token_id", None) or getattr(
+            config, "image_token_index", None
+        )

     @staticmethod
     def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):

@@ -423,31 +433,6 @@ class ModelConfig:
         self.model_path = client.get_local_dir()


-def get_hf_text_config(config: PretrainedConfig):
-    """Get the "sub" config relevant to llm for multi modal models.
-    No op for pure text models.
-    """
-    class_name = config.architectures[0]
-    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
-        # We support non-hf version of llava models, so we do not want to
-        # read the wrong values from the unused default text_config.
-        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
-        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
-        setattr(config, "torch_dtype", torch.float16)
-        return config
-
-    if hasattr(config, "text_config"):
-        # The code operates under the assumption that text_config should have
-        # `num_attention_heads` (among others). Assert here to fail early
-        # if transformers config doesn't align with this assumption.
-        assert hasattr(config.text_config, "num_attention_heads")
-        return config.text_config
-
-    if hasattr(config, "language_config"):
-        return config.language_config
-    else:
-        return config
-
-
 # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,

@@ -537,6 +522,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
 multimodal_model_archs = [
+    "CLIPModel",
     "DeepseekVL2ForCausalLM",
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",

@@ -554,7 +540,6 @@ multimodal_model_archs = [
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
-    "CLIPModel",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
 ]
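The behavioral change in this file is the image-token lookup: some HF configs name the field image_token_id while others (LLaVA-style) use image_token_index, and the new code falls back from one to the other. A minimal sketch of that fallback, with stand-in objects rather than sglang code:

    # Sketch (not sglang code): prefer `image_token_id`, fall back to
    # `image_token_index`, since HF configs are inconsistent about the name.
    from types import SimpleNamespace

    def resolve_image_token_id(config):
        return getattr(config, "image_token_id", None) or getattr(
            config, "image_token_index", None
        )

    # e.g. a LLaVA-style config exposing only `image_token_index`:
    cfg = SimpleNamespace(image_token_index=32000)
    assert resolve_image_token_id(cfg) == 32000

One caveat of the `or` chaining: a token id of 0 would be treated as falsy and trigger the fallback, which is fine in practice only because special-token ids are never 0 in these vocabularies.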
python/sglang/srt/hf_transformers_utils.py

@@ -19,6 +19,7 @@ import warnings
 from pathlib import Path
 from typing import Dict, Optional, Type, Union

+import torch
 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,

@@ -65,6 +66,43 @@ def download_from_hf(model_path: str):
     return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])


+def get_hf_text_config(config: PretrainedConfig):
+    """Get the "sub" config relevant to llm for multi modal models.
+    No op for pure text models.
+    """
+    if config.architectures is not None:
+        class_name = config.architectures[0]
+        if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
+            # We support non-hf version of llava models, so we do not want to
+            # read the wrong values from the unused default text_config.
+            # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
+            # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
+            setattr(config, "torch_dtype", torch.float16)
+            return config
+
+    if hasattr(config, "text_config"):
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(config.text_config, "num_attention_heads")
+        return config.text_config
+
+    if hasattr(config, "language_config"):
+        return config.language_config
+
+    if hasattr(config, "thinker_config"):
+        # qwen2.5 omni
+        thinker_config = config.thinker_config
+        if hasattr(thinker_config, "text_config"):
+            setattr(
+                thinker_config.text_config,
+                "torch_dtype",
+                getattr(thinker_config, "torch_dtype", None),
+            )
+            return thinker_config.text_config
+        return thinker_config
+    else:
+        return config
+
+
 def get_config(
     model: str,
     trust_remote_code: bool,

@@ -80,13 +118,12 @@ def get_config(
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
+    text_config = get_hf_text_config(config=config)

-    # FIXME: Pour contents of janus-pro's langauge_config to first-level
-    if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
-        assert hasattr(config, "language_config")
-        for key, val in config.language_config.__dict__.items():
-            setattr(config, key, val)
-        setattr(config, "architectures", ["MultiModalityCausalLM"])
+    if isinstance(model, str) and text_config is not None:
+        for key, val in text_config.__dict__.items():
+            if not hasattr(config, key) and getattr(text_config, key, None) is not None:
+                setattr(config, key, val)

     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]

@@ -99,6 +136,9 @@ def get_config(
             if not hasattr(config, key):
                 setattr(config, key, val)

+    if config.model_type == "multi_modality":
+        config.update({"architectures": ["MultiModalityCausalLM"]})
+
     if model_override_args:
         config.update(model_override_args)
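get_hf_text_config, now moved here from model_config.py, gains a fourth resolution step for nested thinker_config layouts (Qwen2.5-Omni style). A rough illustration of the resolution order with stand-in objects (not real transformers configs, and with the Llava special case omitted for brevity):

    # Illustrative only: text_config -> language_config -> thinker_config(.text_config)
    from types import SimpleNamespace

    def resolve(config):
        if hasattr(config, "text_config"):
            return config.text_config
        if hasattr(config, "language_config"):
            return config.language_config
        if hasattr(config, "thinker_config"):
            thinker = config.thinker_config
            return getattr(thinker, "text_config", thinker)
        return config  # pure text model: no-op

    omni_like = SimpleNamespace(
        architectures=["SomeOmniModel"],
        thinker_config=SimpleNamespace(
            torch_dtype="bfloat16",
            text_config=SimpleNamespace(num_attention_heads=28),
        ),
    )
    print(resolve(omni_like).num_attention_heads)  # -> 28

The janus-pro special case in get_config is also generalized: instead of matching one model name, any resolved text_config now has its non-None fields hoisted to the top-level config when they are missing there.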
python/sglang/srt/layers/attention/vision.py

@@ -120,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
         flatten_batch: bool = False,
     ) -> Optional[torch.Tensor]:
         r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.

         Args:
             s: sequence length
             cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
python/sglang/srt/managers/io_struct.py

@@ -22,13 +22,15 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

+from sglang.srt.mm_utils import has_valid_data
+
 # handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any

-from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
 from sglang.srt.sampling.sampling_params import SamplingParams

@@ -104,6 +106,9 @@ class GenerateReqInput:
     bootstrap_port: Optional[Union[List[int], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None

+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def normalize_batch_and_arguments(self):
         """
         Normalize the batch size and arguments for the request.

@@ -487,6 +492,9 @@ class EmbeddingReqInput:
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None

+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
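Both request types get the same contains_mm_input helper so callers can test for a real multimodal payload without poking at image_data/audio_data directly. A self-contained sketch of what it computes (the recursive has_valid_data here is a simplified stand-in for the one added in sglang/srt/mm_utils.py below):

    # Sketch: None and (nested) empty lists count as "no data".
    def has_valid_data(data) -> bool:
        if data is None:
            return False
        if isinstance(data, list):
            return any(has_valid_data(item) for item in data)
        return True

    class FakeReq:  # stand-in for GenerateReqInput / EmbeddingReqInput
        image_data = [[]]   # nested empty list: no real payload
        audio_data = None

    req = FakeReq()
    assert not (has_valid_data(req.image_data) or has_valid_data(req.audio_data))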
python/sglang/srt/managers/mm_utils.py

@@ -2,6 +2,7 @@
 Multi-modality utils
 """

+import dataclasses
 import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple

@@ -41,11 +42,26 @@ class MultiModalityDataPaddingPattern:
 class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

-    The padded value in a region enclosed by a token pair with be the same one, as the MultimodalDataItem's pad value
+    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
     """

-    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
+    def __init__(
+        self,
+        data_token_pairs: Optional[List[Tuple[int, int]]],
+        data_start_token_ids: Optional[List[int]] = None,
+    ) -> None:
+        """
+        Args:
+            data_start_token_ids marks the start of a single multimodal data
+            See Minicpmo's slice_start_id for example
+        """
         self.data_token_id_pairs = data_token_pairs
+        self.data_start_token_ids = data_start_token_ids or [
+            s for s, _e in data_token_pairs
+        ]

     def pad_input_tokens(self, input_ids: List[int], mm_inputs: MultimodalInputs

@@ -79,7 +95,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         for start_idx, end_idx in zip(start_indices, end_indices):
             padded_ids.extend(input_ids[last_idx : start_idx + 1])

-            if input_ids[start_idx] in start_token_ids:
+            if input_ids[start_idx] in self.data_start_token_ids:
                 data_idx += 1
                 mm_inputs.data_offsets += [start_idx]

@@ -170,7 +186,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
             output_ids_tensor[start_idx:end_idx] = pad_value
         else:
             logger.warning(f"Skipping region {i} due to None pad_value.")
-
     return output_ids_tensor.tolist()

@@ -202,7 +217,7 @@ def get_embedding_and_mask(
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
-            f"Number of tokens in multimodal embedding does not match those in the input text."
+            f"Number of tokens in multimodal embedding does not match those in the input text. "
             f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
            "tokens from multimodal embeddings."
         )
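The new data_start_token_ids parameter lets a model distinguish token pairs that open a new multimodal item from pairs (like MiniCPM's slice markers) that continue the current one; previously every pair's start token advanced the data index. An illustrative toy (not the sglang implementation, with made-up token ids and pad values):

    # Only ids in data_start_ids advance to the next item's pad value;
    # a slice region reuses the current item's value.
    IM_START, IM_END = 100, 101
    SLICE_START, SLICE_END = 102, 103

    token_pairs = [(IM_START, IM_END), (SLICE_START, SLICE_END)]
    data_start_ids = [IM_START]   # slices do not start a new item
    pad_values = [7, 8]           # one pad value per multimodal item

    input_ids = [1, IM_START, 0, 0, IM_END, SLICE_START, 0, SLICE_END, 2]
    starts = {s for s, _ in token_pairs}
    ends = {e for _, e in token_pairs}

    out, item, inside = [], -1, False
    for tok in input_ids:
        if tok in starts:
            if tok in data_start_ids:
                item += 1
            inside = True
            out.append(tok)
        elif tok in ends:
            inside = False
            out.append(tok)
        else:
            out.append(pad_values[item] if inside else tok)

    print(out)  # [1, 100, 7, 7, 101, 102, 7, 103, 2]  (slice padded with item 0's value)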
python/sglang/srt/managers/multimodal_processors/base_processor.py

@@ -36,9 +36,21 @@ class BaseMultiModalProcessorOutput:
 @dataclasses.dataclass
 class MultimodalSpecialTokens:
-    image_token: Optional[str] = None
-    video_token: Optional[str] = None
-    audio_token: Optional[str] = None
+    image_token: Optional[Union[int, str, List[str]]] = None
+    video_token: Optional[Union[int, str, List[str]]] = None
+    audio_token: Optional[Union[int, str, List[str]]] = None
+
+    def convert_to_str(self, token: Union[str, int], processor) -> str:
+        if token is None:
+            return token
+        if isinstance(token, str):
+            return token
+        return processor.tokenizer.convert_ids_to_tokens([token])[0]
+
+    def convert_to_strs(self, processor):
+        self.image_token = self.convert_to_str(self.image_token, processor)
+        self.video_token = self.convert_to_str(self.video_token, processor)
+        self.audio_token = self.convert_to_str(self.audio_token, processor)
+
     image_token_regex: Optional[re.Pattern] = None
     video_token_regex: Optional[re.Pattern] = None

@@ -74,6 +86,7 @@ class BaseMultimodalProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
+        self.arch = hf_config.architectures[0]
         self.server_args = server_args
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330

@@ -260,19 +273,10 @@ class BaseMultimodalProcessor(ABC):
         """
         if not return_text:
             raise NotImplementedError()
         if image_data is None:
             image_data = []

-        if isinstance(multimodal_tokens.image_token, int):
-            multimodal_tokens.image_token = re.compile(
-                re.escape(
-                    self._processor.tokenizer.convert_ids_to_tokens(
-                        multimodal_tokens.image_token
-                    )
-                )
-            )
-        else:
-            multimodal_tokens.image_token = multimodal_tokens.image_token
-
+        multimodal_tokens.convert_to_strs(self._processor)
         multimodal_tokens_pattern = multimodal_tokens.collect()

         if isinstance(prompt, list) and return_text:

@@ -332,9 +336,9 @@ class BaseMultimodalProcessor(ABC):
             new_text += text_part

         out = BaseMultiModalProcessorOutput(
-            input_text=new_text,
             images=images,
             audios=audios,
+            input_text=new_text,
         )
         out.normalize()
         return out
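The ad-hoc int-to-regex conversion that only handled image_token is replaced by convert_to_strs, which normalizes image, video, and audio tokens uniformly. A hedged sketch of the id-to-string step, with a fake tokenizer standing in for a real HF one (the token id and string are illustrative values):

    # An int special token is mapped through the tokenizer; a string passes
    # through unchanged; None stays None.
    class FakeTokenizer:
        vocab = {151655: "<|image_pad|>"}

        def convert_ids_to_tokens(self, ids):
            return [self.vocab[i] for i in ids]

    def convert_to_str(token, tokenizer):
        if token is None or isinstance(token, str):
            return token
        return tokenizer.convert_ids_to_tokens([token])[0]

    tok = FakeTokenizer()
    assert convert_to_str(151655, tok) == "<|image_pad|>"
    assert convert_to_str("<image>", tok) == "<image>"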
python/sglang/srt/managers/multimodal_processors/minicpm.py

 from typing import List, Union

 import torch
-from transformers import BaseImageProcessorFast

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,

@@ -21,33 +20,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"

-    def process_data_task(self, input_text, images=None, audios=None):
-        if isinstance(images, list) and len(images) == 0:
-            images = None
-        if isinstance(audios, list) and len(audios) == 0:
-            audios = None
-        processor = self._processor
-        args = {}
-        if isinstance(processor, BaseImageProcessorFast):
-            args["device"] = "cuda"
-        result = self._processor.__call__(
-            text=input_text,
-            images=images,
-            audios=audios,
-            return_tensors="pt",
-            chunk_input=True,
-            **args,
-        )
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": getattr(result, "pixel_values", None),
-            "tgt_sizes": getattr(result, "tgt_sizes", None),
-            "audio_features": getattr(result, "audio_features", None),
-            "audio_feature_lens": getattr(result, "audio_feature_lens", None),
-            "audio_bounds": getattr(result, "audio_bounds", None),
-        }
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
python/sglang/srt/managers/schedule_batch.py

@@ -324,8 +324,9 @@ class MultimodalInputs:
     video_token_id: Optional[int] = None

     # audio
-    audio_start_id: Optional[torch.Tensor] = None
-    audio_end_id: Optional[torch.Tensor] = None
+    audio_token_id: Optional[int] = None
+    audio_start_id: Optional[int] = None
+    audio_end_id: Optional[int] = None

     @staticmethod
     def from_dict(obj: dict):

@@ -349,6 +350,7 @@ class MultimodalInputs:
             "slice_end_id",
             "audio_start_id",
             "audio_end_id",
+            "audio_token_id",
         ]
         for arg in optional_args:
             if arg in obj:
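Besides adding audio_token_id, this corrects the audio marker fields from torch.Tensor to plain int annotations. A small sketch of the from_dict pattern the new field joins (stand-in class, not sglang's): optional ids are copied only when present in the incoming dict, so absent keys keep their None defaults.

    import dataclasses
    from typing import Optional

    @dataclasses.dataclass
    class Inputs:  # hypothetical miniature of MultimodalInputs
        audio_token_id: Optional[int] = None
        audio_start_id: Optional[int] = None
        audio_end_id: Optional[int] = None

        @staticmethod
        def from_dict(obj: dict) -> "Inputs":
            ret = Inputs()
            for arg in ("audio_token_id", "audio_start_id", "audio_end_id"):
                if arg in obj:
                    setattr(ret, arg, obj[arg])
            return ret

    print(Inputs.from_dict({"audio_token_id": 200}))  # start/end ids stay None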
python/sglang/srt/managers/tokenizer_manager.py

@@ -459,7 +459,9 @@ class TokenizerManager:
             )
             input_ids = self.tokenizer.encode(input_text)

-        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
+        image_inputs: Optional[Dict] = None
+        if obj.contains_mm_input():
+            image_inputs = await self.mm_processor.process_mm_data_async(
                 image_data=obj.image_data,
                 input_text=input_text or input_ids,
                 request_obj=obj,
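This is where contains_mm_input pays off: the multimodal processor is now invoked only when the request actually carries image/audio data, instead of being awaited unconditionally. A simplified, runnable sketch of the gating with a stub processor (names other than contains_mm_input are stand-ins):

    import asyncio
    from typing import Optional

    class StubProcessor:
        async def process_mm_data_async(self, image_data, input_text) -> dict:
            return {"pixel_values": "<tensor>"}  # pretend payload

    async def tokenize_one(obj: dict, mm_processor: StubProcessor) -> Optional[dict]:
        image_inputs: Optional[dict] = None
        if obj["image_data"]:  # stand-in for obj.contains_mm_input()
            image_inputs = await mm_processor.process_mm_data_async(
                image_data=obj["image_data"], input_text=obj["text"]
            )
        return image_inputs

    # Text-only request: the processor is never invoked, result stays None.
    print(asyncio.run(tokenize_one({"text": "hi", "image_data": []}, StubProcessor())))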
python/sglang/srt/mm_utils.py

@@ -36,6 +36,16 @@ from io import BytesIO
 import numpy as np
 from PIL import Image

+from sglang.srt.utils import flatten_nested_list
+
+
+def has_valid_data(data) -> bool:
+    if data is None:
+        return False
+    if isinstance(data, list):
+        return any(has_valid_data(item) for item in flatten_nested_list(data))
+    return True
+

 def select_best_resolution(original_size, possible_resolutions):
     """
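Expected behavior of the new helper, shown as illustrative calls (requires sglang installed; the inputs are made-up examples):

    from sglang.srt.mm_utils import has_valid_data

    assert has_valid_data("https://example.com/cat.png")
    assert has_valid_data([b"raw-bytes"])
    assert not has_valid_data(None)
    assert not has_valid_data([])        # empty batch
    assert not has_valid_data([[], []])  # nested empties flatten to nothing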
python/sglang/srt/model_executor/model_runner.py

@@ -1165,7 +1165,7 @@ class ModelRunner:
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
         mrope requires keep "rope_deltas" between prompt and decoding phases."""
-        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
         if rope_scaling is None:
             return False
         is_mrope_enabled = "mrope_section" in rope_scaling
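Reading rope_scaling from hf_text_config instead of hf_config matters for models whose rope settings live on the nested text sub-config. A minimal sketch of the detection with stand-in objects (real configs come from transformers; the mrope_section value is illustrative):

    from types import SimpleNamespace

    def model_is_mrope(hf_text_config) -> bool:
        rope_scaling = getattr(hf_text_config, "rope_scaling", {})
        if rope_scaling is None:  # some configs set the field explicitly to None
            return False
        return "mrope_section" in rope_scaling

    qwen_vl_like = SimpleNamespace(rope_scaling={"mrope_section": [16, 24, 24]})
    assert model_is_mrope(qwen_vl_like)
    assert not model_is_mrope(SimpleNamespace(rope_scaling=None))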
python/sglang/srt/models/minicpmo.py

@@ -1520,12 +1520,15 @@ class MiniCPMO(MiniCPMBaseModel):
         slice_start_id: int = mm_input.slice_start_id
         slice_end_id: int = mm_input.slice_end_id

-        media_token_pairs = [
+        data_token_pairs = [
             (im_start_id, im_end_id),
             (slice_start_id, slice_end_id),
             (mm_input.audio_start_id, mm_input.audio_end_id),
         ]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        data_start_token_ids = [im_start_id, mm_input.audio_start_id]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(
+            data_token_pairs=data_token_pairs, data_start_token_ids=data_start_token_ids
+        )

         return pattern.pad_input_tokens(input_ids, mm_input)
python/sglang/srt/models/mllama.py

@@ -865,7 +865,6 @@ class MllamaForConditionalGeneration(nn.Module):
                 pixel_values = torch.cat(
                     [item.pixel_values for item in mm_input.mm_items], dim=0
                 )
-                # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
                 max_num_images = max(max_num_images, pixel_values.shape[1])

                 max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
python/sglang/srt/models/qwen2_5_vl.py

@@ -146,6 +146,8 @@ class Qwen2_5_VisionBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
             qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
test/srt/test_vision_chunked_prefill.py

@@ -147,8 +147,8 @@ class TestVisionChunkedPrefill(CustomTestCase):
     def _test_chunked_prefill(self, batches, num_frames):
         # Chunked
+        try:
             chunked_server_pid = self.launch_server(chunked_prefill_size=1024)
-        try:
             outputs_chunked = []
             for batch, num_frame in zip(batches, num_frames):
                 output_chunked = self.generate_for_video(
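Moving the server launch inside the try block means the cleanup path still runs if the launch itself raises partway through. A generic sketch of that pattern (hypothetical launch/kill callables, not the test's actual helpers):

    def run_with_server(launch, kill, body):
        pid = None
        try:
            pid = launch()           # launch inside try: a partial failure
            return body(pid)         # still reaches the cleanup below
        finally:
            if pid is not None:      # guard: launch may have raised before assigning
                kill(pid)

The guard on pid is the one subtlety: with the launch inside try, the cleanup code must tolerate the case where the handle was never assigned.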