Unverified commit 4395c87a authored by Mick, committed by GitHub

refactor: unify names of the feature field of MultimodalDataItem (#8075)

parent c28ad199
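In short, the modality-specific raw-feature fields of MultimodalDataItem (`pixel_values`, `pixel_values_videos`, `audio_features`, `input_features`) are collapsed into a single `feature` field, and the hash used to derive `pad_value` now falls through `feature` to `precomputed_features`. Below is a minimal sketch of the resulting shape of the class, reconstructed from the hunks in this commit; the `Modality` enum and `hash_feature` here are simplified stand-ins for the sglang originals, not the real implementations:

    import hashlib
    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Optional, Union

    import numpy as np
    import torch


    class Modality(Enum):  # simplified stand-in for sglang's Modality enum
        IMAGE = auto()
        VIDEO = auto()
        AUDIO = auto()


    def hash_feature(feature):  # stand-in for sglang.srt.managers.mm_utils.hash_feature
        data = (
            feature.cpu().numpy().tobytes()
            if isinstance(feature, torch.Tensor)
            else np.asarray(feature).tobytes()
        )
        return int.from_bytes(hashlib.sha256(data).digest()[:8], "little")


    @dataclass
    class MultimodalDataItem:
        """All modalities now share one raw-feature field: `feature`."""

        modality: Modality
        feature: Optional[Union[torch.Tensor, np.ndarray]] = None
        precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
        hash: Optional[int] = None
        pad_value: Optional[int] = None

        def set_pad_value(self):
            # Hash the raw feature if present, otherwise the precomputed one.
            if self.hash is None:
                hashed = self.feature if self.feature is not None else self.precomputed_features
                self.hash = hash_feature(hashed)
            assert self.hash is not None
            self.pad_value = self.hash % (1 << 30)


    # Usage: every modality populates the same field.
    audio_item = MultimodalDataItem(modality=Modality.AUDIO, feature=torch.randn(1, 80, 3000))
    image_item = MultimodalDataItem(modality=Modality.IMAGE, feature=torch.randn(1, 3, 336, 336))
    for item in (audio_item, image_item):
        item.set_pad_value()

Model code then reads `item.feature` regardless of modality, as the per-model hunks below show.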
@@ -78,7 +78,7 @@ class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
         output_lengths = (input_lengths - 2) // 2 + 1
         item = MultimodalDataItem(
-            audio_features=res["input_features"],
+            feature=res["input_features"],
             audio_feature_lens=output_lengths,
             audio_offsets=audio_offsets,
             modality=Modality.AUDIO,
...
@@ -207,13 +207,12 @@ class MultimodalDataItem:
     modality: Modality
     hash: int = None
     pad_value: int = None
-    image_sizes: Tuple[int, int] = None
     offsets: Optional[list] = None

-    # the real data, pixel_values or audio_features
-    # data: Union[List[torch.Tensor], List[np.ndarray]]
-    pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
-    audio_features: Union[torch.Tensor, np.ndarray] = None
+    # the raw features returned by processor, e.g. pixel_values or audio_features
+    feature: Union[torch.Tensor, np.ndarray] = None
+    image_sizes: Tuple[int, int] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None
     precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@@ -238,7 +237,6 @@ class MultimodalDataItem:
     image_grid_hws: Optional[List[torch.Tensor]] = None

     # For gemma3n
-    input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None

     @staticmethod
@@ -254,18 +252,11 @@ class MultimodalDataItem:
         from sglang.srt.managers.mm_utils import hash_feature

         if self.hash is None:
-            if self.precomputed_features is not None:
-                self.hash = hash_feature(self.precomputed_features)
-            elif self.is_audio():
-                if self.audio_features is not None:
-                    self.hash = hash_feature(self.audio_features)
-                elif self.input_features is not None:
-                    self.hash = hash_feature(self.input_features)
-            elif self.is_video():
-                self.hash = hash_feature(self.pixel_values_videos)
+            if self.feature is not None:
+                hashed_feature = self.feature
             else:
-                self.hash = hash_feature(self.pixel_values)
+                hashed_feature = self.precomputed_features
+            self.hash = hash_feature(hashed_feature)

         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
@@ -275,8 +266,7 @@ class MultimodalDataItem:
     def is_audio(self):
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.audio_features)
-            or not MultimodalDataItem.is_empty_list(self.input_features)
+            or not MultimodalDataItem.is_empty_list(self.feature)
         )

     def is_image(self):
@@ -284,13 +274,13 @@ class MultimodalDataItem:
             self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
         ) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+            or not MultimodalDataItem.is_empty_list(self.feature)
         )

     def is_video(self):
         return (self.modality == Modality.VIDEO) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
+            or not MultimodalDataItem.is_empty_list(self.feature)
         )

     def is_valid(self) -> bool:
@@ -311,7 +301,7 @@ class MultimodalDataItem:
         return ret

     def merge(self, other):
-        self.pixel_values += other.pixel_values
+        self.feature += other.feature
         self.image_sizes += other.image_sizes
         self.image_offsets += other.image_offsets
         self.hash = hash((self.hash, other.hash))
@@ -354,7 +344,6 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

-
         for item in ret.mm_items:
             item.set_pad_value()
@@ -1278,11 +1267,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if mm_input is None:
                 continue
             for mm_item in mm_input.mm_items:
-                pixel_values = getattr(mm_item, "pixel_values", None)
+                pixel_values = getattr(mm_item, "feature", None)
                 if isinstance(pixel_values, torch.Tensor):
-                    mm_item.pixel_values = pixel_values.to(
-                        self.device, non_blocking=True
-                    )
+                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
         self.multimodal_inputs = multimodal_inputs
         self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
...
@@ -463,7 +463,7 @@ class CLIPModel(nn.Module):
         if forward_batch.mm_inputs is not None:
             mm_inputs = forward_batch.mm_inputs
             pixel_values_list = [
-                item.pixel_values
+                item.feature
                 for item in flatten_nested_list(
                     [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
                 )
...
@@ -1960,7 +1960,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         self.logits_processor = LogitsProcessor(config)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
+        pixel_values = torch.concat([item.feature for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype
...
@@ -268,9 +268,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
         # TODO: can it be batched ?
         images_in_this_batch = []
         for item in items:
-            assert item.pixel_values.dim() == 4
+            assert item.feature.dim() == 4
             image_feature = self.vision.forward_features(
-                item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+                item.feature.type(next(self.vision.parameters()).dtype).to(
                     device=next(self.vision.parameters()).device
                 )
             )
...
@@ -283,7 +283,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
-        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        all_pixel_values = flatten_nested_list([item.feature for item in items])
         vision_outputs_list = []

         for pixel_values_batch in all_pixel_values:
...
@@ -265,7 +265,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
-        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        all_pixel_values = flatten_nested_list([item.feature for item in items])
         vision_outputs_list = []

         for pixel_values_batch in all_pixel_values:
@@ -316,9 +316,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`).
         """
         # Extract audio features and masks from items
-        all_input_features = flatten_nested_list(
-            [item.input_features for item in items]
-        )
+        all_input_features = flatten_nested_list([item.feature for item in items])
         all_input_features_mask = flatten_nested_list(
             [~item.input_features_mask for item in items]
         )  # Note(Xinyuan): reverse the mask according to the HF implementation
...
@@ -510,7 +510,7 @@ class InternVLChatModel(nn.Module):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = torch.cat([item.pixel_values for item in items])
+        pixel_values = torch.cat([item.feature for item in items])
         image_features = self.extract_feature(pixel_values)
         return image_features
...
@@ -144,7 +144,7 @@ class KimiVLForConditionalGeneration(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = (
-            torch.cat([item.pixel_values for item in items], dim=0)
+            torch.cat([item.feature for item in items], dim=0)
             .type(self.vision_tower.dtype)
             .to(self.vision_tower.device)
         )
...
@@ -186,7 +186,7 @@ class LlavaBaseForCausalLM(nn.Module):
             bs = forward_batch.batch_size
             pixel_values = flatten_nested_list(
                 [
-                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    [item.feature for item in image_inputs[i].mm_items]
                     for i in range(bs)
                     if need_vision[i]
                 ]
@@ -753,7 +753,7 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         features = []
         for item in items:
             # in each item, we assume pixel_values is always batched
-            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            pixel_values, image_sizes = item.feature, item.image_sizes
             image_outputs = self.vision_tower(
                 pixel_values, image_sizes, output_hidden_states=True
             )
...
@@ -135,7 +135,7 @@ class LlavaVidForCausalLM(nn.Module):
         if need_vision.any():
             pixel_values = flatten_nested_list(
                 [
-                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    [item.feature for item in image_inputs[i].mm_items]
                     for i in range(bs)
                     if need_vision[i]
                 ]
...
@@ -1552,9 +1552,7 @@ class MiniCPMO(MiniCPMBaseModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        wavforms = flatten_nested_list(
-            [item.audio_features for item in items if item.audio_features]
-        )
+        wavforms = flatten_nested_list([item.feature for item in items if item.feature])
         # list, [[x1, x2], [y1], [z1]]
         audio_feature_lens_raw = flatten_nested_list(
             [item.audio_feature_lens for item in items if item.audio_feature_lens]
@@ -1659,9 +1657,7 @@ class MiniCPMO(MiniCPMBaseModel):
             List[List[torch.Tensor]]: audio embeddings
         """
         # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = flatten_nested_list(
-            [item.audio_features for item in items if item.audio_features]
-        )
+        wavforms = flatten_nested_list([item.feature for item in items if item.feature])
         # list, [[x1, x2], [y1], [z1]]
         audio_feature_lens_raw = flatten_nested_list(
             [item.audio_feature_lens for item in items if item.audio_feature_lens]
@@ -1778,7 +1774,7 @@ class MiniCPMO(MiniCPMBaseModel):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        pixel_values = flatten_nested_list([item.feature for item in items])
         tgt_sizes = torch.stack(
             flatten_nested_list([item.tgt_size for item in items]), dim=0
         )
...
@@ -724,7 +724,7 @@ class MiniCPMV2_6(MiniCPMBaseModel):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        pixel_values = flatten_nested_list([item.feature for item in items])
         tgt_sizes = torch.stack(
             flatten_nested_list([item.tgt_size for item in items]), dim=0
         )
...
@@ -56,7 +56,7 @@ class Mistral3ForConditionalGeneration:
         features = []
         for item in items:
             # in each item, we assume pixel_values is always batched
-            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            pixel_values, image_sizes = item.feature, item.image_sizes
             image_outputs = self.vision_tower(
                 pixel_values, image_sizes, output_hidden_states=True
             )
...
@@ -838,9 +838,7 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        pixel_values = torch.cat(
-            [item.pixel_values for item in mm_inputs.mm_items], dim=0
-        )
+        pixel_values = torch.cat([item.feature for item in mm_inputs.mm_items], dim=0)
         pad_values = [item.pad_value for item in mm_inputs.mm_items]

         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
@@ -862,7 +860,7 @@ class MllamaForConditionalGeneration(nn.Module):
             if not forward_batch.encoder_cached[i] and mm_input is not None:
                 pixel_values = torch.cat(
-                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                    [item.feature for item in mm_input.mm_items], dim=0
                 )
                 max_num_images = max(max_num_images, pixel_values.shape[1])
@@ -897,7 +895,7 @@ class MllamaForConditionalGeneration(nn.Module):
                 encoder_lens_need.append(forward_batch.encoder_lens[k])
                 pixel_values = torch.cat(
-                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                    [item.feature for item in mm_input.mm_items], dim=0
                 )
                 for j in range(pixel_values.shape[1]):
                     img = pixel_values[0, j]
...
@@ -147,7 +147,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             raise ValueError("Vision model not available for text-only checkpoint")

         pixel_values = (
-            torch.concat([item.pixel_values for item in items])
+            torch.concat([item.feature for item in items])
            .to(next(self.vision_model.parameters()).device)
            .type(next(self.vision_model.parameters()).dtype)
        )
...
@@ -422,9 +422,7 @@ class Phi4MMForCausalLM(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         dtype = next(self.vision_encoder.parameters()).dtype
-        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
-            dtype
-        )
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
         image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
         image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
         image_embeds = self.vision_encoder(
...
@@ -497,7 +497,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -508,9 +508,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-            [getattr(item, "pixel_values_videos") for item in items], dim=0
-        ).type(self.visual.dtype)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
...
@@ -118,7 +118,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
-        input_features = torch.cat([item.audio_features for item in items], dim=0).type(
+        input_features = torch.cat([item.feature for item in items], dim=0).type(
             self.audio_tower.dtype
         )
...
@@ -484,7 +484,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -495,9 +495,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-            [item.pixel_values_videos for item in items], dim=0
-        ).type(self.visual.dtype)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
...
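Note for downstream code: anything outside this diff that still reads the old modality-specific attributes (`pixel_values`, `pixel_values_videos`, `audio_features`, `input_features`) will fail after this rename. A small migration shim like the following (illustrative only, not part of this commit; the attribute names are the ones touched by the diff) can read items produced by either layout:

    def get_raw_feature(item):
        """Return the raw feature from an old- or new-style MultimodalDataItem.

        Hypothetical compatibility helper for migration purposes only.
        """
        for name in ("feature", "pixel_values", "pixel_values_videos",
                     "audio_features", "input_features"):
            value = getattr(item, name, None)
            if value is not None:
                return value
        # Fall back to precomputed embeddings if no raw feature is present.
        return getattr(item, "precomputed_features", None)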