"tests/pytorch/vscode:/vscode.git/clone" did not exist on "97fbd94dfc436ce8500770adced4a5f57889347c"
Unverified Commit 0b3b3e9a authored by ash-sigh's avatar ash-sigh Committed by GitHub
Browse files

transfer mrope_position_delta to device when first running (#11047)

parent a1d5bc4c
......@@ -575,9 +575,15 @@ class ForwardBatch:
device=model_runner.device,
)
else:
mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
model_runner.device, non_blocking=True
)
if mm_input.mrope_position_delta.device.type != model_runner.device:
# transfer mrope_position_delta to device when the first running,
# avoiding successvie host-to-device data transfer
mm_input.mrope_position_delta = (
mm_input.mrope_position_delta.to(
model_runner.device, non_blocking=True
)
)
mrope_position_deltas = mm_input.mrope_position_delta.flatten()
mrope_positions_list[batch_idx] = (
(mrope_position_deltas + self.seq_lens[batch_idx] - 1)
.unsqueeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment