Unverified Commit 0b3b3e9a authored by ash-sigh's avatar ash-sigh Committed by GitHub
Browse files

transfer mrope_position_delta to device when first running (#11047)

parent a1d5bc4c
...@@ -575,9 +575,15 @@ class ForwardBatch: ...@@ -575,9 +575,15 @@ class ForwardBatch:
device=model_runner.device, device=model_runner.device,
) )
else: else:
mrope_position_deltas = mm_input.mrope_position_delta.flatten().to( if mm_input.mrope_position_delta.device.type != model_runner.device:
# transfer mrope_position_delta to device when the first running,
# avoiding successvie host-to-device data transfer
mm_input.mrope_position_delta = (
mm_input.mrope_position_delta.to(
model_runner.device, non_blocking=True model_runner.device, non_blocking=True
) )
)
mrope_position_deltas = mm_input.mrope_position_delta.flatten()
mrope_positions_list[batch_idx] = ( mrope_positions_list[batch_idx] = (
(mrope_position_deltas + self.seq_lens[batch_idx] - 1) (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
.unsqueeze(0) .unsqueeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment