Unverified Commit 6e322052 authored by drbh's avatar drbh Committed by GitHub
Browse files

fix: create position ids for text only input (#2714)

* fix: create position ids for text only input

* fix: prefer repeat over expand to avoid clone
parent 01dacf8e
......@@ -468,7 +468,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[:, i, :] = llm_positions.to(position_ids.device)
else:
position_ids = (
torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device)
.view(1, 1, -1)
.repeat(3, batch_input_ids.shape[0], 1)
)
return position_ids
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment