Unverified commit b1dc87a0, authored by Lukas Geiger, committed by GitHub
Browse files

[Models][Gemma4] Prevent GPU/CPU sync in `embed_input_ids` (#39234)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 79a5b632
...@@ -1254,9 +1254,10 @@ class Gemma4ForConditionalGeneration( ...@@ -1254,9 +1254,10 @@ class Gemma4ForConditionalGeneration(
# computation (using token_type_ids == 0 as text_mask). # computation (using token_type_ids == 0 as text_mask).
# Replicate this: map image token positions to token 0. # Replicate this: map image token positions to token 0.
if is_multimodal is not None: if is_multimodal is not None:
is_multimodal = is_multimodal.to(input_ids.device)
ple_input_ids = torch.where( ple_input_ids = torch.where(
is_multimodal, torch.zeros_like(input_ids), input_ids is_multimodal.to(input_ids.device, non_blocking=True),
torch.zeros_like(input_ids),
input_ids,
) )
else: else:
ple_input_ids = input_ids ple_input_ids = input_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment