Unverified commit fc679696, authored by Chatcharin Sangbutsarakum, committed by GitHub
Browse files

Fix `DotsOCR` tensor type (#26281)


Signed-off-by: what_in_the_nim <chatcharinsang@gmail.com>
parent ab5e7d93
...@@ -617,7 +617,7 @@ class DotsVisionTransformer(nn.Module): ...@@ -617,7 +617,7 @@ class DotsVisionTransformer(nn.Module):
def device(self) -> torch.device: def device(self) -> torch.device:
return self.patch_embed.patchifier.proj.weight.device return self.patch_embed.patchifier.proj.weight.device
def get_pos_ids_by_grid(self, grid_thw): def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
pos_ids = [] pos_ids = []
for t, h, w in grid_thw: for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
...@@ -643,10 +643,10 @@ class DotsVisionTransformer(nn.Module): ...@@ -643,10 +643,10 @@ class DotsVisionTransformer(nn.Module):
return pos_ids return pos_ids
def rot_pos_emb(self, grid_thw): def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
pos_ids = self.get_pos_ids_by_grid(grid_thw) pos_ids = self.get_pos_ids_by_grid(grid_thw)
pos_ids = torch.cat(pos_ids, dim=0) pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max() max_grid_size = max(max(h, w) for _, h, w in grid_thw)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb return rotary_pos_emb
...@@ -667,13 +667,13 @@ class DotsVisionTransformer(nn.Module): ...@@ -667,13 +667,13 @@ class DotsVisionTransformer(nn.Module):
def forward( def forward(
self, hidden_states: torch.Tensor, grid_thw: list[list[int]] self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
) -> torch.Tensor: ) -> torch.Tensor:
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# Convert grid_thw to tensor (always expecting list format now) # Convert grid_thw to tensor (always expecting list format now)
grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long) grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)
hidden_states = hidden_states.to(self.dtype) hidden_states = hidden_states.to(self.dtype)
hidden_states = self.patch_embed(hidden_states, grid_thw) hidden_states = self.patch_embed(hidden_states, grid_thw)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave( cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum( ).cumsum(
...@@ -807,7 +807,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -807,7 +807,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
rope_type="rope_3d", rope_type="rope_3d",
) )
else: else:
image_embeds = self.vision_tower(pixel_values, grid_thw)[ image_embeds = self.vision_tower(pixel_values, grid_thw_list)[
:, : self.config.hidden_size :, : self.config.hidden_size
] ]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment