"vllm/vscode:/vscode.git/clone" did not exist on "4695397dcfef693a0a10f1eb8bf77ea905c54829"
Unverified Commit e283976f authored by WeiQing Chen's avatar WeiQing Chen Committed by GitHub
Browse files

[Performance][MM] Building the inverse permutation in O(n) time in...


[Performance][MM] Building the inverse permutation in O(n) time in Qwen2_5_VisionTransformer (#24443)
Signed-off-by: default avatarJunhong <liujunhong11@huawei.com>
Co-authored-by: default avatarJunhong <liujunhong11@huawei.com>
parent 46876dff
......@@ -717,6 +717,15 @@ class Qwen2_5_VisionTransformer(nn.Module):
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens
@staticmethod
def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
# building the inverse permutation in O(n) time
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.numel(),
device=perm.device,
dtype=perm.dtype)
return inv
def forward(
self,
x: torch.Tensor,
......@@ -760,6 +769,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
rotary_pos_emb = torch.cat(rotary_pos_emb)
window_index = torch.cat(window_index)
# compute reverse indices
reverse_indices = self.invert_permutation(window_index)
cu_window_seqlens = torch.cat(cu_window_seqlens)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
cu_seqlens = torch.cat(cu_seqlens)
......@@ -813,7 +824,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
# adapter
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment