Unverified Commit 4ed9053e authored by timmy-feng's avatar timmy-feng Committed by GitHub
Browse files

Remove mrope position sync (#9460)


Co-authored-by: default avatarNathan Wang <nathan.r.wang@gmail.com>
parent 5e19b159
......@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
return position_ids, mrope_position_deltas
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> torch.Tensor:
return torch.tensor(
[
list(
range(
context_len + mrope_position_delta,
seq_len + mrope_position_delta,
)
)
for _ in range(3)
]
)
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
......
......@@ -516,24 +516,23 @@ class ForwardBatch:
for batch_idx in range(batch_size):
mm_input = batch.multimodal_inputs[batch_idx]
if self.forward_mode.is_decode():
mrope_position_deltas = (
[0]
if mm_input is None
else flatten_nested_list(mm_input.mrope_position_delta.tolist())
)
next_input_positions = []
for mrope_position_delta in mrope_position_deltas:
# batched deltas needs to be processed separately
# Convert list of lists to tensor with shape [3, seq_len]
next_input_positions += [
MRotaryEmbedding.get_next_input_positions(
mrope_position_delta,
int(self.seq_lens[batch_idx]) - 1,
int(self.seq_lens[batch_idx]),
)
]
# 3 * N
mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
if mm_input is None:
mrope_positions_list[batch_idx] = torch.full(
(3, 1),
self.seq_lens[batch_idx] - 1,
dtype=torch.int64,
device=model_runner.device,
)
else:
mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
model_runner.device, non_blocking=True
)
mrope_positions_list[batch_idx] = (
(mrope_position_deltas + self.seq_lens[batch_idx] - 1)
.unsqueeze(0)
.repeat(3, 1)
)
elif self.forward_mode.is_extend():
extend_seq_len, extend_prefix_len = (
batch.extend_seq_lens[batch_idx],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment