Unverified Commit 594c1610 authored by Joao Gante, committed by GitHub

Mamba & RecurrentGemma: enable strict signature (#31549)

* enable strict signature

* this should not have been deleted

* recurrent_gemma too
parent ae9dd02e
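
The change repeated across the `GenerationMixin` decoding loops below boils down to one pattern: output controls are added to `model_inputs` only when they are actually set, so a model whose `forward()` has a strict signature (no `**kwargs` catch-all) never receives arguments it does not declare. A minimal standalone sketch of that pattern (illustrative only; the helper name `forward_step` is not part of the library):

```python
# Minimal sketch of the pattern this commit applies in the generation loops.
def forward_step(model, model_inputs, output_attentions=False, output_hidden_states=False):
    # prepare variable output controls (note: some models won't accept all output controls)
    model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
    model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

    # forward pass: only the controls that were explicitly requested reach the model
    return model(**model_inputs, return_dict=True)
```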
@@ -2692,13 +2692,12 @@ class GenerationMixin:
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -2919,6 +2918,10 @@ class GenerationMixin:
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # if sequential is True, split the input to batches of batch_size and run sequentially
             if sequential:
                 if any(
@@ -2944,24 +2947,13 @@ class GenerationMixin:
                     model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
                 )
                 outputs_per_sub_batch = [
-                    self(
-                        **inputs_per_sub_batch,
-                        return_dict=True,
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                    )
-                    for inputs_per_sub_batch in inputs_per_sub_batches
+                    self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
                 ]

                 outputs = stack_model_outputs(outputs_per_sub_batch)

             else:  # Unchanged original behavior
-                outputs = self(
-                    **model_inputs,
-                    return_dict=True,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                )
+                outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3241,12 +3233,12 @@ class GenerationMixin:
             # do one decoder step on all beams of all sentences in batch
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3522,12 +3514,11 @@ class GenerationMixin:
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3793,11 +3784,11 @@ class GenerationMixin:
                 model_inputs["num_logits_to_keep"] = candidate_length + 1

             # 2.2. Run a forward pass on the candidate sequence
-            outputs = self(
-                **model_inputs,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs)

             # 2.3. Process the new logits
             new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
......
@@ -545,7 +545,6 @@ class MambaModel(MambaPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
     ) -> Union[Tuple, MambaOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -673,7 +672,6 @@ class MambaForCausalLM(MambaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, MambaCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
......
@@ -684,7 +684,6 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -823,7 +822,6 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, CausalLMOutput]:
         r"""
         Args:
......
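
With the `**kwargs` catch-all removed from the Mamba and RecurrentGemma forward signatures, an unexpected keyword argument now fails fast with a `TypeError` instead of being silently dropped. A minimal standalone sketch of that behaviour (toy classes, not the actual model implementations):

```python
# Toy illustration of why the strict signature matters; these classes are
# stand-ins, not the real MambaModel / RecurrentGemmaModel.
class LenientForward:
    def __call__(self, input_ids, output_hidden_states=None, **kwargs):
        # unknown kwargs (e.g. output_attentions) are silently swallowed
        return "ok"


class StrictForward:
    def __call__(self, input_ids, output_hidden_states=None):
        # no **kwargs catch-all: unknown kwargs raise immediately
        return "ok"


LenientForward()(input_ids=[1, 2], output_attentions=True)  # passes, flag is dropped
try:
    StrictForward()(input_ids=[1, 2], output_attentions=True)
except TypeError as err:
    print(err)  # "... got an unexpected keyword argument 'output_attentions'"
```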