Unverified commit 86ee9491, authored by yuafng and committed by GitHub
Browse files

Fix tensor device and dtype placement in Qwen2VL model (#26219)


Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Yuanfeng Li <yuanfengli@meta.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 4570535e
...@@ -720,7 +720,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -720,7 +720,7 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb = self.rot_pos_emb(grid_thw) rotary_pos_emb = self.rot_pos_emb(grid_thw)
# compute cu_seqlens # compute cu_seqlens
grid_thw_ = torch.tensor(grid_thw) grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2], cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
grid_thw_[:, 0]).cumsum( grid_thw_[:, 0]).cumsum(
dim=0, dtype=torch.int32) dim=0, dtype=torch.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment