Unverified Commit 07cadab2 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Model][Qwen3VL] Cache positional embedding indices (#28475)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 637f2921
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" """Inference-only Qwen3VL model compatible with HuggingFace weights."""
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import lru_cache, partial
from itertools import islice from itertools import islice
from typing import Any from typing import Any
...@@ -416,30 +416,41 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -416,30 +416,41 @@ class Qwen3_VisionTransformer(nn.Module):
def device(self) -> torch.device: def device(self) -> torch.device:
return self.patch_embed.proj.weight.device return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw: list[list[int]]): @staticmethod
pos_ids = [] @lru_cache(maxsize=1024)
max_grid_size = max(max(h, w) for _, h, w in grid_thw) def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
for t, h, w in grid_thw: hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) h_div = h // spatial_merge_size
w_div = w // spatial_merge_size
hpos_ids = hpos_ids.reshape( hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size, h_div,
self.spatial_merge_size, spatial_merge_size,
w // self.spatial_merge_size, w_div,
self.spatial_merge_size, spatial_merge_size,
) )
hpos_ids = hpos_ids.permute(0, 2, 1, 3) hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten() hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
wpos_ids = wpos_ids.reshape( wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size, h_div,
self.spatial_merge_size, spatial_merge_size,
w // self.spatial_merge_size, w_div,
self.spatial_merge_size, spatial_merge_size,
) )
wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten() wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
def rot_pos_emb(self, grid_thw: list[list[int]]):
max_grid_size = max(max(h, w) for _, h, w in grid_thw)
pos_ids = [
self.rot_pos_ids(h, w, self.spatial_merge_size)
if t == 1
else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
for t, h, w in grid_thw
]
pos_ids = torch.cat(pos_ids, dim=0) pos_ids = torch.cat(pos_ids, dim=0)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment