Unverified Commit 04428160 authored by Pablo Montalvo, committed by GitHub

Fix generate with `inputs_embeds` as input (#32493)

* I think inputs_embeds has ndim == 3

* fix sequence length catch

* add generate test

* [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama

* skip whisper

* fix bart test

* more fixes
parent b01f9c48
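The core of the fix: `inputs_embeds` is a 3-D `(batch, seq_len, hidden)` tensor, so the old 2-D unpacking used for `input_ids` raised a `ValueError` whenever a static cache triggered the mask-preparation path. A minimal sketch of the failing and corrected unpacking (tensor sizes are illustrative, not taken from the diff):

```python
import torch

inputs_embeds = torch.randn(2, 7, 64)  # (batch, seq_len, hidden): what an embedding layer returns

# The old code treated it like input_ids (2-D) and crashed:
#   batch_size, sequence_length = inputs_embeds.shape  -> ValueError: too many values to unpack

# Fixed: unpack all three dimensions and discard the hidden size.
batch_size, sequence_length, _ = inputs_embeds.shape
device = inputs_embeds.device
print(batch_size, sequence_length, device)  # 2 7 cpu
```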
@@ -756,17 +756,18 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
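The same hunk is applied to every model file below; the point of routing everything through `model_inputs` is that `inputs_embeds` can still be handed to `prepare_inputs_for_generation` on later decoding steps, so the static-cache mask code must read shapes from the dict it is about to return rather than from the raw arguments. A minimal sketch of the two phases (shapes and the token value are illustrative; variable names mirror the diff):

```python
import torch

inputs_embeds = torch.randn(1, 4, 32)  # a 4-token prompt passed as embeddings, hidden size 32
input_ids = torch.tensor([[42]])       # the single token produced at a later decoding step

# Prefill step (cache_position[0] == 0): the user-supplied embeddings drive the forward pass.
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}

# Decoding step: only the new token id is fed, even though the caller's `inputs_embeds`
# argument may still be non-None, so shape inference must look at `model_inputs`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if model_inputs["inputs_embeds"] is not None:
    batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape  # 3-D unpack
else:
    batch_size, sequence_length = model_inputs["input_ids"].shape         # 2-D unpack
```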
@@ -1132,17 +1132,18 @@ class CohereForCausalLM(CoherePreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1403,17 +1403,18 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1270,17 +1270,18 @@ class FalconForCausalLM(FalconPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1143,17 +1143,18 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -104,7 +104,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
             )
-
     return causal_mask
@@ -301,7 +300,6 @@ class Gemma2Attention(nn.Module):
             attn_weights = attn_weights / self.config.attn_logit_softcapping
             attn_weights = torch.tanh(attn_weights)
             attn_weights = attn_weights * self.config.attn_logit_softcapping
-
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
@@ -501,11 +499,9 @@ class Gemma2SdpaAttention(Gemma2Attention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
         causal_mask = attention_mask
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:
@@ -516,7 +512,6 @@ class Gemma2SdpaAttention(Gemma2Attention):
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
-
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -581,7 +576,6 @@ class Gemma2DecoderLayer(nn.Module):
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                 if attention_mask.shape[-1] <= 1:  # when decoding
                     attention_mask = attention_mask[:, :, :, -self.sliding_window :]
-
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -1013,7 +1007,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -1080,7 +1073,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
-
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
@@ -1096,22 +1088,20 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
             # The clone here is for the same reason as for `position_ids`.
-            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
-
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
-
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
@@ -1122,7 +1112,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
                 cache_position=cache_position,
                 batch_size=batch_size,
             )
-
         model_inputs.update(
             {
                 "position_ids": position_ids,
...
@@ -970,17 +970,18 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1220,17 +1220,18 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.embed_out.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1100,17 +1100,18 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1265,17 +1265,18 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1136,17 +1136,18 @@ class NemotronForCausalLM(NemotronPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1176,17 +1176,18 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -993,17 +993,18 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1278,17 +1278,18 @@ class PhiForCausalLM(PhiPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1318,17 +1318,18 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1176,17 +1176,18 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1372,17 +1372,18 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1271,17 +1271,18 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1152,17 +1152,18 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
 
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
...
@@ -1540,3 +1540,8 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
     @unittest.skip
     def test_save_load_fast_init_from_base(self):
         pass
+
+    @unittest.skip(reason="Generate needs input ids")
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        # generate only works with input ids for bartforcausalLM
+        pass
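For models that do support embeddings as the generation prompt, the new `test_inputs_embeds_matches_input_ids_with_generate` check boils down to comparing greedy continuations from `input_ids` and from the equivalent `inputs_embeds`. A sketch of that behaviour, assuming a tiny illustrative checkpoint (the checkpoint name and token count are not taken from the diff), keeping in mind that generation from `inputs_embeds` returns only the newly generated tokens:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative tiny model
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

prompt = tok("Hello there", return_tensors="pt")
inputs_embeds = model.get_input_embeddings()(prompt.input_ids)

# Greedy decoding from token ids (output includes the prompt tokens)...
ids_out = model.generate(input_ids=prompt.input_ids, max_new_tokens=5, do_sample=False)
# ...and from the equivalent embeddings (output contains only the new tokens).
embeds_out = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5, do_sample=False)

# With greedy decoding the continuations are expected to match.
assert torch.equal(ids_out[:, prompt.input_ids.shape[1]:], embeds_out)
```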