Unverified Commit 6ca2c91b authored by Itay Etelis's avatar Itay Etelis Committed by GitHub
Browse files

[Model] Use mm_position to compute mrope positions for Qwen3-Omni (#33010)


Signed-off-by: default avatarItay Etelis <itay.etelis@ibm.com>
Co-authored-by: default avatarItay Etelis <itay.etelis@ibm.com>
parent e33192b2
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
This example shows how to use vLLM for running offline inference This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen2.5-Omni (thinker only). with the correct prompt format on Qwen3-Omni (thinker only).
""" """
from typing import NamedTuple from typing import NamedTuple
...@@ -112,23 +112,51 @@ def get_multi_audios_query() -> QueryResult: ...@@ -112,23 +112,51 @@ def get_multi_audios_query() -> QueryResult:
) )
def get_multi_images_query() -> QueryResult:
question = "What are the differences between these two images?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
"<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"image": [
convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"),
],
},
},
limit_mm_per_prompt={
"image": 2,
},
)
query_map = { query_map = {
"mixed_modalities": get_mixed_modalities_query, "mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query, "use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query, "multi_audios": get_multi_audios_query,
"multi_images": get_multi_images_query,
} }
def main(args): def main(args):
model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct" model_name = args.model
query_result = query_map[args.query_type]() query_result = query_map[args.query_type]()
llm = LLM( llm = LLM(
model=model_name, model=model_name,
max_model_len=12800, max_model_len=args.max_model_len,
max_num_seqs=5, max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt, limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed, seed=args.seed,
tensor_parallel_size=args.tensor_parallel_size,
gpu_memory_utilization=args.gpu_memory_utilization,
) )
# We set temperature to 0.2 so that outputs can be different # We set temperature to 0.2 so that outputs can be different
...@@ -161,6 +189,31 @@ def parse_args(): ...@@ -161,6 +189,31 @@ def parse_args():
default=0, default=0,
help="Set the seed when initializing `vllm.LLM`.", help="Set the seed when initializing `vllm.LLM`.",
) )
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
help="Model name or path.",
)
parser.add_argument(
"--tensor-parallel-size",
"-tp",
type=int,
default=1,
help="Tensor parallel size for distributed inference.",
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.9,
help="GPU memory utilization (0.0 to 1.0).",
)
parser.add_argument(
"--max-model-len",
type=int,
default=12800,
help="Maximum model context length.",
)
return parser.parse_args() return parser.parse_args()
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" """Inference-only Qwen3-Omni-Moe model (thinker part)."""
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Any from typing import Any
...@@ -104,10 +104,7 @@ from .utils import ( ...@@ -104,10 +104,7 @@ from .utils import (
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import get_vit_attn_backend
get_llm_pos_ids_for_vision,
get_vit_attn_backend,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1867,323 +1864,268 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1867,323 +1864,268 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
return loaded_weights return loaded_weights
def get_mrope_input_positions( def _compute_audio_token_count(self, audio_feature_length: int) -> int:
self, """Compute audio tokens from feature length using Qwen3-Omni formula."""
input_tokens: list[int], return _get_feat_extract_output_lengths(
mm_features: list[MultiModalFeatureSpec], torch.tensor([audio_feature_length])
) -> tuple[torch.Tensor, int]: ).item()
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"audio_feature_lengths",
"use_audio_in_video",
},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( def _get_audio_for_video_mapping(
image_grid_thw self, mm_features: list[MultiModalFeatureSpec]
) ) -> tuple[dict[int, int], set[int]]:
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( """
video_grid_thw Map video offset -> paired audio_feature_length for use_audio_in_video.
)
When use_audio_in_video=True, audio is interleaved within video.
The pairing is based on feature order in mm_features.
input_ids = torch.tensor(input_tokens) Returns:
if input_ids is None or input_ids.ndim != 1: Tuple of (video_offset -> audio_feature_length mapping,
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") set of paired audio offsets to skip)
"""
videos_with_audio = [
f
for f in mm_features
if f.modality == "video"
and f.data.get("use_audio_in_video")
and f.data["use_audio_in_video"].data.item()
]
audios = [f for f in mm_features if f.modality == "audio"]
mapping: dict[int, int] = {}
paired_audio_offsets: set[int] = set()
for i, video_f in enumerate(videos_with_audio):
if i < len(audios):
audio_len = audios[i].data["audio_feature_lengths"].data.item()
mapping[video_f.mm_position.offset] = audio_len
paired_audio_offsets.add(audios[i].mm_position.offset)
return mapping, paired_audio_offsets
def iter_mm_features(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, str, dict[str, Any]]]:
"""
Iterate over multimodal features sorted by position offset.
seq_len = input_ids.shape[0] Yields: (offset, modality, feature_data) where feature_data contains:
- image: {"grid_t", "grid_h", "grid_w", "t_factor"}
- video: {"grid_t", "grid_h", "grid_w", "t_factor",
"use_audio_in_video", "audio_feature_length"}
- audio: {"audio_feature_length"}
"""
config = self.config
spatial_merge_size = config.vision_config.spatial_merge_size
position_id_per_seconds = config.position_id_per_seconds
if isinstance(audio_feature_lengths, list): sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset)
audio_feature_lengths = torch.tensor( audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping(
audio_feature_lengths, dtype=torch.long sorted_features
) )
if not len(second_per_grid_ts) and len(video_grid_thw): for mm_feature in sorted_features:
offset = mm_feature.mm_position.offset
modality = mm_feature.modality
if modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
yield (
offset,
"image",
{
"grid_t": t,
"grid_h": h // spatial_merge_size,
"grid_w": w // spatial_merge_size,
"t_factor": position_id_per_seconds,
},
)
elif modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
second_per_grid_ts = 2.0 second_per_grid_ts = 2.0
second_per_grids = ( if mm_feature.data.get("second_per_grid_ts"):
torch.ones(len(video_grid_thw), dtype=torch.float32) second_per_grid_ts = mm_feature.data[
* second_per_grid_ts "second_per_grid_ts"
].data.item()
use_audio_in_video = bool(
mm_feature.data.get("use_audio_in_video")
and mm_feature.data["use_audio_in_video"].data.item()
)
yield (
offset,
"video",
{
"grid_t": t,
"grid_h": h // spatial_merge_size,
"grid_w": w // spatial_merge_size,
"t_factor": second_per_grid_ts * position_id_per_seconds,
"use_audio_in_video": use_audio_in_video,
"audio_feature_length": audio_for_video.get(offset),
},
) )
else: elif modality == "audio":
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32) if offset not in paired_audio_offsets:
audio_len = mm_feature.data["audio_feature_lengths"].data.item()
yield offset, "audio", {"audio_feature_length": audio_len}
config = self.config def _compute_interleaved_positions(
spatial_merge_size = config.vision_config.spatial_merge_size self, start_idx: int, data: dict[str, Any]
image_token_id = config.image_token_id ) -> tuple[np.ndarray, int]:
video_token_id = config.video_token_id """
audio_token_id = config.audio_token_id Compute positions for interleaved video+audio using Qwen3 token-by-token
vision_start_token_id = config.vision_start_token_id interleaving logic.
audio_start_token_id = config.audio_start_token_id
position_id_per_seconds = config.position_id_per_seconds
vision_start_indices = torch.argwhere( Returns: (position_ids [3, N], total_token_count)
input_ids == vision_start_token_id """
).squeeze(1) grid_t = data["grid_t"]
if vision_start_indices.numel() > 0: grid_h = data["grid_h"]
vision_tokens = input_ids[vision_start_indices + 1] grid_w = data["grid_w"]
t_factor = data["t_factor"]
audio_feature_length = data["audio_feature_length"]
audio_len = self._compute_audio_token_count(audio_feature_length)
h_index = np.tile(
np.arange(grid_h).reshape(1, -1, 1), (grid_t, 1, grid_w)
).flatten()
w_index = np.tile(
np.arange(grid_w).reshape(1, 1, -1), (grid_t, grid_h, 1)
).flatten()
t_index_raw = np.arange(grid_t)
t_index_scaled = (t_index_raw * t_factor).astype(np.int64)
t_index = np.repeat(t_index_scaled, grid_h * grid_w)
video_pos = np.stack([t_index, h_index, w_index]) + start_idx
audio_pos = np.broadcast_to(np.arange(audio_len), (3, audio_len)) + start_idx
video_t_values = video_pos[0]
audio_t_values = audio_pos[0]
pos_ids_list: list[np.ndarray] = []
video_idx, audio_idx = 0, 0
num_video = grid_t * grid_h * grid_w
while video_idx < num_video and audio_idx < audio_len:
if video_t_values[video_idx] <= audio_t_values[audio_idx]:
pos_ids_list.append(video_pos[:, video_idx : video_idx + 1])
video_idx += 1
else: else:
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype) pos_ids_list.append(audio_pos[:, audio_idx : audio_idx + 1])
audio_nums = torch.sum(input_ids == audio_start_token_id) audio_idx += 1
image_nums = (vision_tokens == image_token_id).sum()
video_nums = ( if video_idx < num_video:
(vision_tokens == audio_start_token_id).sum() pos_ids_list.append(video_pos[:, video_idx:])
if use_audio_in_video if audio_idx < audio_len:
else (vision_tokens == video_token_id).sum() pos_ids_list.append(audio_pos[:, audio_idx:])
)
total_tokens = num_video + audio_len
return np.concatenate(pos_ids_list, axis=1), total_tokens
llm_pos_ids_list: list[torch.Tensor] = [] def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
"""Compute M-RoPE input positions using mm_features directly."""
seq_len = len(input_tokens)
llm_pos_ids_list: list[np.ndarray] = []
st = 0 st = 0
image_idx = 0
video_idx = 0
audio_idx = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
multimodal_nums = (
image_nums + audio_nums
if use_audio_in_video
else image_nums + video_nums + audio_nums
) # noqa: E501
for _ in range(multimodal_nums): for offset, modality, data in self.iter_mm_features(mm_features):
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 text_len = offset - st
if (image_token_id in input_tokens or video_token_id in input_tokens) and ( st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
remain_videos > 0 or remain_images > 0
):
ed_vision_start = input_tokens.index(vision_start_token_id, st)
else:
ed_vision_start = len(input_tokens) + 1
if audio_token_id in input_tokens and remain_audios > 0:
ed_audio_start = input_tokens.index(audio_start_token_id, st)
else:
ed_audio_start = len(input_tokens) + 1
min_ed = min(ed_vision_start, ed_audio_start)
if min_ed == ed_audio_start: if text_len > 0:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
+ st_idx
) )
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 st_idx += text_len
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx] bos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
) llm_pos_ids_list.append(bos_pos)
llm_pos_ids = ( st_idx += 1
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx if modality == "audio":
) audio_tokens = self._compute_audio_token_count(
llm_pos_ids_list.append(llm_pos_ids) data["audio_feature_length"]
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 )
eos_len = 1 audio_pos = (
llm_pos_ids_list.append( np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) )
+ st_idx llm_pos_ids_list.append(audio_pos)
) st_idx = int(audio_pos.max()) + 1
st += text_len + bos_len + audio_len + eos_len
audio_idx += 1 eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
remain_audios -= 1 llm_pos_ids_list.append(eos_pos)
elif ( st = offset + 1 + audio_tokens + 1
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == image_token_id elif modality == "image":
): grid_t = data["grid_t"]
text_len = min_ed - st grid_h = data["grid_h"]
if text_len != 0: grid_w = data["grid_w"]
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 t_factor = data["t_factor"]
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long) grid_indices = np.indices((grid_t, grid_h, grid_w))
.view(1, -1) if t_factor != 1.0:
.expand(3, -1) grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
+ st_idx llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 image_len = grid_t * grid_h * grid_w
bos_len = 1 st_idx = int(llm_pos_ids_list[-1].max()) + 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
+ st_idx llm_pos_ids_list.append(eos_pos)
) st = offset + 1 + image_len + 1
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = image_grid_thw[image_idx][0] elif modality == "video":
grid_hs = image_grid_thw[:, 1] grid_t = data["grid_t"]
grid_ws = image_grid_thw[:, 2] grid_h = data["grid_h"]
t_index = torch.arange(grid_t) * position_id_per_seconds grid_w = data["grid_w"]
llm_pos_ids = get_llm_pos_ids_for_vision( t_factor = data["t_factor"]
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
) if not data["use_audio_in_video"]:
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) grid_indices = np.indices((grid_t, grid_h, grid_w))
llm_pos_ids_list.append(llm_pos_ids) if t_factor != 1.0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
eos_len = 1 llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) video_len = grid_t * grid_h * grid_w
+ st_idx st_idx = int(llm_pos_ids_list[-1].max()) + 1
)
st += text_len + bos_len + image_len + eos_len eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
image_idx += 1 llm_pos_ids_list.append(eos_pos)
remain_images -= 1 st = offset + 1 + video_len + 1
elif (
min_ed == ed_vision_start
and input_ids[ed_vision_start + 1] == video_token_id
and not use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
llm_pos_ids_list.append(
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
llm_pos_ids_list.append(
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
st += text_len + bos_len + video_len + eos_len
video_idx += 1
remain_videos -= 1
elif (
min_ed == ed_vision_start
and ed_vision_start + 1 == ed_audio_start
and use_audio_in_video
):
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long)
.view(1, -1)
.expand(3, -1)
+ st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
bos_len = 1
bos_block = (
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx]
)
audio_llm_pos_ids = (
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
)
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* float(second_per_grids[video_idx].item())
* position_id_per_seconds
)
video_llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if (
video_llm_pos_ids[0][video_data_index]
<= audio_llm_pos_ids[0][audio_data_index]
):
llm_pos_ids_list.append(
video_llm_pos_ids[
:, video_data_index : video_data_index + 1
]
)
video_data_index += 1
else: else:
llm_pos_ids_list.append( audio_bos_pos = np.broadcast_to(np.array([st_idx - 1]), (3, 1))
audio_llm_pos_ids[ llm_pos_ids_list.append(audio_bos_pos)
:, audio_data_index : audio_data_index + 1
] pos_ids, _ = self._compute_interleaved_positions(st_idx, data)
) llm_pos_ids_list.append(pos_ids)
audio_data_index += 1 st_idx = int(pos_ids.max()) + 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append( eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
video_llm_pos_ids[ llm_pos_ids_list.append(eos_pos)
:, video_data_index : video_llm_pos_ids.shape[-1] llm_pos_ids_list.append(eos_pos)
]
) video_len = grid_t * grid_h * grid_w
if audio_data_index < audio_llm_pos_ids.shape[-1]: audio_len = self._compute_audio_token_count(
llm_pos_ids_list.append( data["audio_feature_length"]
audio_llm_pos_ids[
:, audio_data_index : audio_llm_pos_ids.shape[-1]
]
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
eos_len = 1
eos_block = (
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
+ st_idx
) )
llm_pos_ids_list.append(eos_block) st = offset + 2 + video_len + audio_len + 2
llm_pos_ids_list.append(eos_block)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < len(input_tokens): if st < seq_len:
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
text_len = len(input_tokens) - st text_len = seq_len - st
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
+ st_idx
) )
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
if llm_positions.shape[1] != seq_len: if llm_positions.shape[1] != seq_len:
raise RuntimeError("Position ids length mismatch with input ids length") raise RuntimeError("Position ids length mismatch with input ids length")
mrope_position_delta = llm_positions.max() + 1 - seq_len mrope_position_delta = int(llm_positions.max()) + 1 - seq_len
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment