Unverified Commit b1dc87a0 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Models][Gemma4] Prevent GPU/CPU sync in `embed_input_ids` (#39234)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 79a5b632
......@@ -1254,9 +1254,10 @@ class Gemma4ForConditionalGeneration(
# computation (using token_type_ids == 0 as text_mask).
# Replicate this: map image token positions to token 0.
if is_multimodal is not None:
is_multimodal = is_multimodal.to(input_ids.device)
ple_input_ids = torch.where(
is_multimodal, torch.zeros_like(input_ids), input_ids
is_multimodal.to(input_ids.device, non_blocking=True),
torch.zeros_like(input_ids),
input_ids,
)
else:
ple_input_ids = input_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment