Unverified Commit 4ed9053e authored by timmy-feng, committed by GitHub

Remove mrope position sync (#9460)


Co-authored-by: Nathan Wang <nathan.r.wang@gmail.com>
parent 5e19b159
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         return position_ids, mrope_position_deltas
 
-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-
 
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
 
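For context, the deleted staticmethod built, for a single request, the next M-RoPE positions as one integer range repeated across the three rope sections, materialized on the CPU. A minimal standalone sketch of that computation (hypothetical inputs, defined outside the class):

import torch

# Sketch of what the removed helper computed: for one request with a given
# mrope_position_delta, the next-token positions are a single integer range
# repeated across the 3 M-RoPE sections. During decode, seq_len - context_len
# is 1, so the result is a [3, 1] tensor.
def next_input_positions(mrope_position_delta: int, context_len: int, seq_len: int) -> torch.Tensor:
    return torch.arange(
        context_len + mrope_position_delta, seq_len + mrope_position_delta
    ).unsqueeze(0).repeat(3, 1)

# Hypothetical values: delta = -5, decoding the token at index 42.
print(next_input_positions(-5, 42, 43))  # [3, 1] tensor filled with 37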
@@ -516,24 +516,23 @@ class ForwardBatch:
         for batch_idx in range(batch_size):
             mm_input = batch.multimodal_inputs[batch_idx]
             if self.forward_mode.is_decode():
-                mrope_position_deltas = (
-                    [0]
-                    if mm_input is None
-                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
-                )
-                next_input_positions = []
-                for mrope_position_delta in mrope_position_deltas:
-                    # batched deltas needs to be processed separately
-                    # Convert list of lists to tensor with shape [3, seq_len]
-                    next_input_positions += [
-                        MRotaryEmbedding.get_next_input_positions(
-                            mrope_position_delta,
-                            int(self.seq_lens[batch_idx]) - 1,
-                            int(self.seq_lens[batch_idx]),
-                        )
-                    ]
-                # 3 * N
-                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+                # 3 * N
+                if mm_input is None:
+                    mrope_positions_list[batch_idx] = torch.full(
+                        (3, 1),
+                        self.seq_lens[batch_idx] - 1,
+                        dtype=torch.int64,
+                        device=model_runner.device,
+                    )
+                else:
+                    mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
+                        model_runner.device, non_blocking=True
+                    )
+                    mrope_positions_list[batch_idx] = (
+                        (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
+                        .unsqueeze(0)
+                        .repeat(3, 1)
+                    )
             elif self.forward_mode.is_extend():
                 extend_seq_len, extend_prefix_len = (
                     batch.extend_seq_lens[batch_idx],
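The decode branch is where the position sync is removed: the old path called .tolist() on mrope_position_delta (a device-to-host copy, hence a blocking sync on CUDA) and rebuilt the positions on the CPU one delta at a time, while the new path keeps the deltas on the device and broadcasts them to the [3, N] shape; requests without multimodal input get a torch.full((3, 1), seq_len - 1) tensor directly. A minimal sketch with made-up standalone values (CPU tensors, for illustration only) showing the two paths agree:

import torch

# Hypothetical standalone values standing in for self.seq_lens[batch_idx]
# and mm_input.mrope_position_delta.flatten().
seq_len = torch.tensor(43)
deltas = torch.tensor([-5])

# Old path: .tolist() pulls the deltas to the host (a sync point on CUDA),
# then one [3, 1] block per delta is built on the CPU and concatenated.
old = torch.cat(
    [torch.tensor([[int(seq_len) - 1 + d]] * 3) for d in deltas.tolist()],
    dim=1,
)

# New path: stay on the deltas' device and broadcast to [3, N].
new = (deltas + seq_len - 1).unsqueeze(0).repeat(3, 1)

assert torch.equal(old, new)  # both are a [3, 1] tensor filled with 37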