Unverified commit e8640ee9, authored by Vincent Zhong, committed by GitHub
Browse files

[smol] [perf] Inverse perm improvement (#11482)


Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
parent d0a64c7e
......@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.utils import permute_inv
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
......@@ -405,6 +406,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
# Move window_index to the same device as x before using it to index x
window_index = window_index.to(device=x.device)
reverse_indices = permute_inv(window_index)
# Ensure rotary_pos_emb is on the same device/dtype as x
rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
......@@ -451,8 +453,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
# adapter
x = self.merger(x)
reverse_indices = torch.argsort(window_index)
x = x[reverse_indices, :]
return x
......
......@@ -53,3 +53,9 @@ def create_fused_set_kv_buffer_arg(
v_scale=layer.v_scale,
cache_loc=forward_batch.out_cache_loc,
)
def permute_inv(perm: torch.Tensor) -> torch.Tensor:
    """Return the inverse of the permutation stored in ``perm``.

    The result ``inv_perm`` satisfies ``inv_perm[perm[i]] == i`` for every
    position ``i``. For a valid permutation this is the same tensor that
    ``torch.argsort(perm)`` produces, but it is built with a single indexed
    write instead of a sort.

    Args:
        perm: Index tensor expected to hold each value in
            ``0 .. perm.numel() - 1`` exactly once. This is not validated;
            repeated or out-of-range values give an undefined or failing
            result. NOTE(review): the positions written are generated as a
            flat ``arange``, so a 1-D ``perm`` is assumed - confirm callers.

    Returns:
        A tensor with the same shape, dtype and device as ``perm`` that
        holds the inverse permutation.
    """
    # Every slot is overwritten by the scatter below (given a valid
    # permutation), so an uninitialised buffer is sufficient.
    inv_perm = torch.empty_like(perm)
    # Position i of the input lands at index perm[i] of the output.
    inv_perm[perm] = torch.arange(
        perm.numel(), device=perm.device, dtype=perm.dtype
    )
    return inv_perm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment