Unverified Commit 01dd39ba authored by Mick, committed by GitHub

refactor: minor refactors regarding multimodal processing (#6187)

parent b3f3d610
@@ -22,7 +22,11 @@ from typing import List, Optional, Set, Union
import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_hf_text_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip
@@ -209,7 +213,13 @@ class ModelConfig:
        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()
        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
        config = self.hf_config
        # multimodal
        self.image_token_id = getattr(config, "image_token_id", None) or getattr(
            config, "image_token_index", None
        )
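A minimal sketch (not part of the commit) of what the new fallback covers: configs that only expose `image_token_index` now resolve to the same attribute. The `_Cfg` class and the id value below are made up for illustration.

```python
# Hypothetical stand-in for a HF PretrainedConfig that only defines image_token_index.
class _Cfg:
    image_token_index = 32000  # arbitrary illustrative id

cfg = _Cfg()
image_token_id = getattr(cfg, "image_token_id", None) or getattr(
    cfg, "image_token_index", None
)
print(image_token_id)  # -> 32000
```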
    @staticmethod
    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
@@ -423,31 +433,6 @@ class ModelConfig:
            self.model_path = client.get_local_dir()
def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No op for pure text models.
    """
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        # We support non-hf version of llava models, so we do not want to
        # read the wrong values from the unused default text_config.
        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
        setattr(config, "torch_dtype", torch.float16)
        return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config

    if hasattr(config, "language_config"):
        return config.language_config
    else:
        return config
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
@@ -537,6 +522,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Grok1VForCausalLM",
@@ -554,7 +540,6 @@ multimodal_model_archs = [
    "MllamaForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "CLIPModel",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
]
...
@@ -19,6 +19,7 @@ import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union

import torch
from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
@@ -65,6 +66,43 @@ def download_from_hf(model_path: str):
    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No op for pure text models.
    """
    if config.architectures is not None:
        class_name = config.architectures[0]
        if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
            # We support non-hf version of llava models, so we do not want to
            # read the wrong values from the unused default text_config.
            # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
            # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
            setattr(config, "torch_dtype", torch.float16)
            return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config

    if hasattr(config, "language_config"):
        return config.language_config

    if hasattr(config, "thinker_config"):
        # qwen2.5 omni
        thinker_config = config.thinker_config
        if hasattr(thinker_config, "text_config"):
            setattr(
                thinker_config.text_config,
                "torch_dtype",
                getattr(thinker_config, "torch_dtype", None),
            )
            return thinker_config.text_config
        return thinker_config
    else:
        return config
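A quick sketch (not from the commit) of the lookup order the new helper implements; `SimpleNamespace` objects stand in for HF config classes, and `resolve_text_config` is a simplified illustration rather than the actual function.

```python
from types import SimpleNamespace

# Illustrative stand-ins for HF configs.
text_cfg = SimpleNamespace(num_attention_heads=32, hidden_size=4096)
mm_cfg = SimpleNamespace(architectures=["SomeVLForConditionalGeneration"], text_config=text_cfg)
plain_cfg = SimpleNamespace(architectures=["LlamaForCausalLM"])

def resolve_text_config(config):
    # Simplified mirror of the lookup order: text_config, language_config,
    # thinker_config(.text_config), otherwise the config itself.
    for attr in ("text_config", "language_config"):
        if hasattr(config, attr):
            return getattr(config, attr)
    if hasattr(config, "thinker_config"):
        thinker = config.thinker_config
        return getattr(thinker, "text_config", thinker)
    return config

print(resolve_text_config(mm_cfg) is text_cfg)      # True
print(resolve_text_config(plain_cfg) is plain_cfg)  # True
```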
def get_config(
    model: str,
    trust_remote_code: bool,
@@ -80,13 +118,12 @@ def get_config(
    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
    # FIXME: Pour contents of janus-pro's langauge_config to first-level
    if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
        assert hasattr(config, "language_config")
        for key, val in config.language_config.__dict__.items():
            setattr(config, key, val)
        setattr(config, "architectures", ["MultiModalityCausalLM"])
    text_config = get_hf_text_config(config=config)
    if isinstance(model, str) and text_config is not None:
        for key, val in text_config.__dict__.items():
            if not hasattr(config, key) and getattr(text_config, key, None) is not None:
                setattr(config, key, val)
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
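A small illustration (not part of the diff) of the generalized hoisting in `get_config`: attributes present on the resolved text sub-config but missing on the top-level config are copied up, replacing the janus-pro special case. `SimpleNamespace` objects are placeholders for real configs.

```python
from types import SimpleNamespace

config = SimpleNamespace(model_type="some_vlm")
text_config = SimpleNamespace(model_type="llama", num_attention_heads=32, hidden_size=4096)

# Mirrors the new hoisting loop in get_config.
for key, val in text_config.__dict__.items():
    if not hasattr(config, key) and getattr(text_config, key, None) is not None:
        setattr(config, key, val)

print(config.num_attention_heads)  # 32 (hoisted from the text sub-config)
print(config.model_type)           # 'some_vlm' (existing attribute left untouched)
```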
@@ -99,6 +136,9 @@ def get_config(
            if not hasattr(config, key):
                setattr(config, key, val)

    if config.model_type == "multi_modality":
        config.update({"architectures": ["MultiModalityCausalLM"]})

    if model_override_args:
        config.update(model_override_args)
...
@@ -120,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
        flatten_batch: bool = False,
    ) -> Optional[torch.Tensor]:
        r"""
        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
        Args:
            s: sequence length
            cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
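For orientation, a small sketch (not from the commit) of the non-causal, block-diagonal mask the corrected docstring describes, built here with shape `(1, 1, s, s)`; the helper name and values are illustrative only.

```python
import torch

def toy_noncausal_mask(s: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Boolean block-diagonal mask: each position may attend to every position
    # inside its own segment, with no causal restriction.
    mask = torch.zeros((1, 1, s, s), dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        mask[0, 0, start:end, start:end] = True
    return mask

# Two segments of lengths 3 and 2 in a length-5 sequence.
print(toy_noncausal_mask(5, torch.tensor([0, 3, 5]))[0, 0].int())
```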
...
@@ -22,13 +22,15 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

from sglang.srt.mm_utils import has_valid_data

# handle serialization of Image for pydantic
if TYPE_CHECKING:
    from PIL.Image import Image
else:
    Image = Any

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -104,6 +106,9 @@ class GenerateReqInput:
    bootstrap_port: Optional[Union[List[int], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    def contains_mm_input(self) -> bool:
        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
    def normalize_batch_and_arguments(self):
        """
        Normalize the batch size and arguments for the request.
@@ -487,6 +492,9 @@ class EmbeddingReqInput:
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None

    def contains_mm_input(self) -> bool:
        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
...
@@ -2,6 +2,7 @@
Multi-modality utils
"""

import dataclasses
import logging
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
@@ -41,11 +42,26 @@ class MultiModalityDataPaddingPattern:
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs) """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
The padded value in a region enclosed by a token pair with be the same one, as the MultimodalDataItem's pad value
This strategy should be applied when data content is marked by start/end token pairs in the input sequence. This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
""" """
    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
    def __init__(
        self,
        data_token_pairs: Optional[List[Tuple[int, int]]],
        data_start_token_ids: Optional[List[int]] = None,
    ) -> None:
        """
        Args:
            data_start_token_ids: token ids that mark the start of a single multimodal data item.
                See MiniCPMO's slice_start_id for an example.
        """
        self.data_token_id_pairs = data_token_pairs
        self.data_start_token_ids = data_start_token_ids or [
            s for s, _e in data_token_pairs
        ]
    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
@@ -79,7 +95,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] in start_token_ids:
            if input_ids[start_idx] in self.data_start_token_ids:
                data_idx += 1
                mm_inputs.data_offsets += [start_idx]
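A tiny sketch (not part of the diff) of the new constructor default: when `data_start_token_ids` is not given, the start token of every pair is used, which preserves the old behaviour; an explicit list (as MiniCPMO now passes) restricts which tokens count as the start of a new data item. Token id values are arbitrary.

```python
data_token_pairs = [(1001, 1002), (2001, 2002)]  # e.g. (<image>, </image>), (<audio>, </audio>)

# Default: every pair's start token marks a new multimodal data item.
default_start_ids = None or [s for s, _e in data_token_pairs]
print(default_start_ids)  # [1001, 2001]

# Explicit override: only the listed ids advance data_idx in pad_input_tokens.
explicit_start_ids = [1001] or [s for s, _e in data_token_pairs]
print(explicit_start_ids)  # [1001]
```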
@@ -170,7 +186,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
                output_ids_tensor[start_idx:end_idx] = pad_value
            else:
                logger.warning(f"Skipping region {i} due to None pad_value.")

        return output_ids_tensor.tolist()
@@ -202,7 +217,7 @@ def get_embedding_and_mask(
    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
        logger.warning(
            f"Number of tokens in multimodal embedding does not match those in the input text. "
            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
            "tokens from multimodal embeddings."
        )
...
@@ -36,9 +36,21 @@ class BaseMultiModalProcessorOutput:
@dataclasses.dataclass
class MultimodalSpecialTokens:
    image_token: Optional[str] = None
    image_token: Optional[Union[int, str, List[str]]] = None
    video_token: Optional[str] = None
    video_token: Optional[Union[int, str, List[str]]] = None
    audio_token: Optional[str] = None
    audio_token: Optional[Union[int, str, List[str]]] = None
    def convert_to_str(self, token: Union[str, int], processor) -> str:
        if token is None:
            return token
        if isinstance(token, str):
            return token
        return processor.tokenizer.convert_ids_to_tokens([token])[0]

    def convert_to_strs(self, processor):
        self.image_token = self.convert_to_str(self.image_token, processor)
        self.video_token = self.convert_to_str(self.video_token, processor)
        self.audio_token = self.convert_to_str(self.audio_token, processor)

    image_token_regex: Optional[re.Pattern] = None
    video_token_regex: Optional[re.Pattern] = None
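A hedged sketch (not from the commit) of what `convert_to_str` does with an integer token id; `_DummyTokenizer` and `_DummyProcessor` are made-up stand-ins for a HF processor and its tokenizer.

```python
class _DummyTokenizer:
    def convert_ids_to_tokens(self, ids):
        return [f"<token_{i}>" for i in ids]

class _DummyProcessor:
    tokenizer = _DummyTokenizer()

def convert_to_str(token, processor):
    # Same shape as MultimodalSpecialTokens.convert_to_str: strings (and None)
    # pass through, integer ids are mapped to their token strings.
    if token is None or isinstance(token, str):
        return token
    return processor.tokenizer.convert_ids_to_tokens([token])[0]

print(convert_to_str("(<image>./</image>)", _DummyProcessor()))  # unchanged string
print(convert_to_str(128010, _DummyProcessor()))                 # '<token_128010>'
```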
@@ -74,6 +86,7 @@ class BaseMultimodalProcessor(ABC):
    def __init__(self, hf_config, server_args, _processor):
        self.hf_config = hf_config
        self._processor = _processor
        self.arch = hf_config.architectures[0]
        self.server_args = server_args
        # FIXME: not accurate, model and image specific
        self.NUM_TOKEN_PER_FRAME = 330
@@ -260,19 +273,10 @@ class BaseMultimodalProcessor(ABC):
        """
        if not return_text:
            raise NotImplementedError()
        if image_data is None:
            image_data = []

        if isinstance(multimodal_tokens.image_token, int):
            multimodal_tokens.image_token = re.compile(
                re.escape(
                    self._processor.tokenizer.convert_ids_to_tokens(
                        multimodal_tokens.image_token
                    )
                )
            )
        else:
            multimodal_tokens.image_token = multimodal_tokens.image_token
        multimodal_tokens.convert_to_strs(self._processor)
        multimodal_tokens_pattern = multimodal_tokens.collect()

        if isinstance(prompt, list) and return_text:
@@ -332,9 +336,9 @@ class BaseMultimodalProcessor(ABC):
                new_text += text_part

        out = BaseMultiModalProcessorOutput(
            input_text=new_text,
            images=images,
            audios=audios,
            input_text=new_text,
        )
        out.normalize()
        return out
...
from typing import List, Union

import torch
from transformers import BaseImageProcessorFast

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
@@ -21,33 +20,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
        self.image_token = "(<image>./</image>)"
        self.audio_token = "(<audio>./</audio>)"
    def process_data_task(self, input_text, images=None, audios=None):
        if isinstance(images, list) and len(images) == 0:
            images = None
        if isinstance(audios, list) and len(audios) == 0:
            audios = None
        processor = self._processor
        args = {}
        if isinstance(processor, BaseImageProcessorFast):
            args["device"] = "cuda"
        result = self._processor.__call__(
            text=input_text,
            images=images,
            audios=audios,
            return_tensors="pt",
            chunk_input=True,
            **args,
        )
        return {
            "input_ids": result.input_ids,
            "pixel_values": getattr(result, "pixel_values", None),
            "tgt_sizes": getattr(result, "tgt_sizes", None),
            "audio_features": getattr(result, "audio_features", None),
            "audio_feature_lens": getattr(result, "audio_feature_lens", None),
            "audio_bounds": getattr(result, "audio_bounds", None),
        }

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
...
@@ -324,8 +324,9 @@ class MultimodalInputs:
    video_token_id: Optional[int] = None

    # audio
    audio_start_id: Optional[torch.Tensor] = None
    audio_end_id: Optional[torch.Tensor] = None
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    @staticmethod
    def from_dict(obj: dict):
@@ -349,6 +350,7 @@ class MultimodalInputs:
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
...
@@ -459,7 +459,9 @@ class TokenizerManager:
            )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
        image_inputs: Optional[Dict] = None
        if obj.contains_mm_input():
            image_inputs = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                input_text=input_text or input_ids,
                request_obj=obj,
...
@@ -36,6 +36,16 @@ from io import BytesIO
import numpy as np
from PIL import Image

from sglang.srt.utils import flatten_nested_list


def has_valid_data(data) -> bool:
    if data is None:
        return False
    if isinstance(data, list):
        return any(has_valid_data(item) for item in flatten_nested_list(data))
    return True
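A short usage sketch (not part of the diff) of `has_valid_data`, with a simplified stand-in for `flatten_nested_list`: empty or all-None containers do not count as multimodal input, anything concrete does.

```python
def _flatten(data):
    # Simplified stand-in for sglang.srt.utils.flatten_nested_list.
    out = []
    for item in data:
        if isinstance(item, list):
            out.extend(_flatten(item))
        else:
            out.append(item)
    return out

def has_valid_data(data) -> bool:
    if data is None:
        return False
    if isinstance(data, list):
        return any(has_valid_data(item) for item in _flatten(data))
    return True

print(has_valid_data(None))           # False
print(has_valid_data([]))             # False
print(has_valid_data([[None], []]))   # False
print(has_valid_data(["img_0.png"]))  # True -> contains_mm_input() would be True
```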
def select_best_resolution(original_size, possible_resolutions):
    """
...
@@ -1165,7 +1165,7 @@ class ModelRunner:
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
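A speculative sketch (not from the commit) of why the lookup presumably moved to `hf_text_config`: for models whose generation settings live on a nested text config (e.g. the qwen2.5 omni thinker config handled in `get_hf_text_config` above), the `mrope_section` entry may not be visible on the top-level config. The dicts below are illustrative, not real HF configs.

```python
# Illustrative config contents (plain dicts, not actual HF configs).
top_level_config = {"rope_scaling": None}
text_sub_config = {"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}}

def model_is_mrope(cfg: dict) -> bool:
    rope_scaling = cfg.get("rope_scaling") or {}
    return "mrope_section" in rope_scaling

print(model_is_mrope(top_level_config))  # False -> mrope would be missed here
print(model_is_mrope(text_sub_config))   # True
```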
...
@@ -1520,12 +1520,15 @@ class MiniCPMO(MiniCPMBaseModel):
        slice_start_id: int = mm_input.slice_start_id
        slice_end_id: int = mm_input.slice_end_id

        media_token_pairs = [
        data_token_pairs = [
            (im_start_id, im_end_id),
            (slice_start_id, slice_end_id),
            (mm_input.audio_start_id, mm_input.audio_end_id),
        ]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
        data_start_token_ids = [im_start_id, mm_input.audio_start_id]
        pattern = MultiModalityDataPaddingPatternTokenPairs(
            data_token_pairs=data_token_pairs, data_start_token_ids=data_start_token_ids
        )
        return pattern.pad_input_tokens(input_ids, mm_input)
...
@@ -865,7 +865,6 @@ class MllamaForConditionalGeneration(nn.Module):
                pixel_values = torch.cat(
                    [item.pixel_values for item in mm_input.mm_items], dim=0
                )
                # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
                max_num_images = max(max_num_images, pixel_values.shape[1])
                max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
...
@@ -146,6 +146,8 @@ class Qwen2_5_VisionBlock(nn.Module):
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=flatten_batch,
@@ -147,8 +147,8 @@ class TestVisionChunkedPrefill(CustomTestCase):
    def _test_chunked_prefill(self, batches, num_frames):
        # Chunked
        try:
        chunked_server_pid = self.launch_server(chunked_prefill_size=1024)
        try:
            outputs_chunked = []
            for batch, num_frame in zip(batches, num_frames):
                output_chunked = self.generate_for_video(
...