Unverified Commit 941b7fc0 authored by chenxiao's avatar chenxiao Committed by GitHub
Browse files

Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) (#11763)



* Avoid creating tensor in CosmosAttnProcessor2_0 (#11761)

* up

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
parent 76a62ac9
...@@ -187,9 +187,15 @@ class CosmosAttnProcessor2_0: ...@@ -187,9 +187,15 @@ class CosmosAttnProcessor2_0:
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA # 4. Prepare for GQA
if torch.onnx.is_in_onnx_export():
query_idx = torch.tensor(query.size(3), device=query.device) query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device) key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device) value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3) key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3) value = value.repeat_interleave(query_idx // value_idx, dim=3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment