    Billy Bradley authored
    In assisted decoding, pass model_kwargs to model's forward call (fix prepare_inputs_for_generation in all models) (#25242)
    
    * In assisted decoding, pass model_kwargs to model's forward call
    
    Previously, assisted decoding ignored any additional kwargs that it
    did not explicitly handle. This was inconsistent with the other
    generation methods, which pass model_kwargs through
    prepare_inputs_for_generation and forward the returned dict to the
    model's forward call.
    
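    The prepare_inputs_for_generation method had to be amended in all
    models, because it previously kept only the last input ID whenever
    past_key_values was passed. In assisted decoding the main model
    processes several candidate tokens in a single forward pass, so every
    token not yet covered by the cache must be kept.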
    
    * Improve variable names in _extend_attention_mask
    
    * Refactor the token_type_ids extension into a separate function
    
    * Replace deepcopy with copy to optimize performance
    
    * Update the new Persimmon model with the Llama changes for assisted generation
    
    * Update prepare_inputs_for_generation in the new Mistral model for assisted generation
    
    * Update position_ids creation in Falcon's prepare_inputs_for_generation to support assisted generation
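    The first item says assisted decoding now routes model_kwargs through
    prepare_inputs_for_generation and forwards the result, like the other
    generation methods. This is a minimal, hypothetical sketch of that flow.
    ToyModel, run_main_model and my_extra_kwarg are invented for
    illustration, and __call__ stands in for the model's forward call.
    The point is that an extra kwarg the decoding loop does not know about
    still reaches the model.

```python
from typing import Any, Dict, List


class ToyModel:
    """Hypothetical stand-in for a causal LM that records what its forward call receives."""

    def __init__(self) -> None:
        self.received: Dict[str, Any] = {}

    def prepare_inputs_for_generation(self, input_ids: List[List[int]], **model_kwargs: Any) -> Dict[str, Any]:
        # Build the forward-call inputs from input_ids plus every extra kwarg.
        return {"input_ids": input_ids, **model_kwargs}

    def __call__(self, **forward_kwargs: Any) -> Dict[str, Any]:
        # Stands in for the model's forward call; just records its inputs.
        self.received = forward_kwargs
        return {"logits": None}


def run_main_model(model: ToyModel, candidate_ids: List[List[int]], model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Route all model_kwargs through prepare_inputs_for_generation and forward
    # the returned dict unchanged, instead of hand-picking known kwargs.
    model_inputs = model.prepare_inputs_for_generation(candidate_ids, **model_kwargs)
    return model(**model_inputs)


model = ToyModel()
run_main_model(model, [[5, 6, 7]], {"attention_mask": [[1, 1, 1]], "my_extra_kwarg": True})
print(sorted(model.received))  # ['attention_mask', 'input_ids', 'my_extra_kwarg']
```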
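    The first item also explains why prepare_inputs_for_generation can no
    longer keep only the last input ID once a cache exists. Below is a
    sketch of prefix trimming under stated assumptions. It drops only the
    tokens the cache already covers, and falls back to the old "last ID
    only" behaviour when the caller passed just the newest token. The
    helper name trim_input_ids is hypothetical, and past_length is passed
    in directly. With the legacy tuple cache it would typically be read
    from the sequence dimension of the first layer's keys.

```python
from typing import List


def trim_input_ids(input_ids: List[List[int]], past_length: int) -> List[List[int]]:
    """Hypothetical helper: drop the tokens already covered by the KV cache."""
    seq_len = len(input_ids[0])
    if seq_len > past_length:
        # Assisted decoding: one or more uncached candidate tokens follow
        # the cached prefix, so keep all of them.
        remove_prefix_length = past_length
    else:
        # The caller already passed only the newest token(s): keep the last ID.
        remove_prefix_length = seq_len - 1
    return [row[remove_prefix_length:] for row in input_ids]


print(trim_input_ids([[10, 11, 12, 13, 14]], past_length=3))  # [[13, 14]]
print(trim_input_ids([[14]], past_length=4))                  # [[14]]
```

    The old slicing, input_ids[:, -1:], would have returned only [[14]]
    in the first call. That silently discards candidate token 13, which
    the main model must verify.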
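    The commit message names _extend_attention_mask but does not show its
    body. This sketch assumes the helper grows the mask with ones so it
    covers newly appended candidate tokens. It also assumes the helper
    targets decoder_attention_mask for encoder-decoder models and
    attention_mask otherwise. Nested lists stand in for tensors, and the
    name extend_attention_mask and its signature are illustrative only.

```python
from typing import Any, Dict


def extend_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Hypothetical sketch: pad the relevant mask with ones up to new_length."""
    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs  # nothing to extend

    mask = model_kwargs[mask_key]
    length_diff = new_length - len(mask[0])
    if length_diff > 0:
        # Candidate tokens are real tokens, so they get mask value 1.
        model_kwargs[mask_key] = [row + [1] * length_diff for row in mask]
    return model_kwargs


print(extend_attention_mask({"attention_mask": [[0, 1, 1]]}, 5, is_encoder_decoder=False))
# {'attention_mask': [[0, 1, 1, 1, 1]]}
```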
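    The token_type_ids item refactors the extension logic into its own
    function without describing it. As a hypothetical sketch, this assumes
    each new position repeats the final token type of its row, since the
    candidate tokens continue the current segment. The name
    extend_token_type_ids is illustrative, and nested lists stand in for
    tensors.

```python
from typing import Any, Dict


def extend_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Hypothetical sketch: grow token_type_ids by repeating each row's last type."""
    token_type_ids = model_kwargs.get("token_type_ids")
    if token_type_ids is None:
        return model_kwargs  # model does not use token types

    length_diff = new_length - len(token_type_ids[0])
    if length_diff > 0:
        # New candidate tokens belong to the same segment as the last existing token.
        model_kwargs["token_type_ids"] = [row + [row[-1]] * length_diff for row in token_type_ids]
    return model_kwargs


print(extend_token_type_ids({"token_type_ids": [[0, 0, 1]]}, 5))
# {'token_type_ids': [[0, 0, 1, 1, 1]]}
```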
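    The "Replace deepcopy with copy" item rests on a property of Python's
    copy module. copy.copy on a dict builds a new dict that shares the
    same value objects. copy.deepcopy recursively duplicates every value,
    which for kwargs holding large tensors or a KV cache is expensive.
    This small demonstration uses a hypothetical kwargs dict with lists
    in place of tensors. It assumes the copied kwargs are updated by
    rebinding keys rather than mutating values in place.

```python
import copy

# Hypothetical kwargs dict; in generation the values would be tensors.
model_kwargs = {"attention_mask": [[1, 1, 1]], "use_cache": True}

shallow = copy.copy(model_kwargs)
deep = copy.deepcopy(model_kwargs)

# A shallow copy is a new dict whose values are the same objects (nothing duplicated).
shares_values = shallow["attention_mask"] is model_kwargs["attention_mask"]
# A deep copy duplicates every value, which is what makes it costly for big tensors.
duplicates_values = deep["attention_mask"] is not model_kwargs["attention_mask"]

# Rebinding a key on the shallow copy leaves the original dict untouched.
shallow["attention_mask"] = [row + [1] for row in shallow["attention_mask"]]
original_unchanged = model_kwargs["attention_mask"] == [[1, 1, 1]]

print(shares_values, duplicates_values, original_unchanged)  # True True True
```

    The caveat of the shallow copy is that an in-place mutation of a
    shared value, such as appending to the list, would be visible through
    both dicts.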
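    The last item changes how Falcon's prepare_inputs_for_generation
    creates position_ids, but the message does not show the new code.
    This sketch assumes the Llama-style approach. Positions come from the
    cumulative sum of the attention mask minus one, and padded slots are
    filled with 1. Then the positions of all new input tokens are kept,
    not just the last one. The helper positions_from_mask is hypothetical
    and uses nested lists instead of tensors. num_new_tokens plays the
    role of input_ids.shape[1] after prefix trimming and is assumed to be
    at least 1.

```python
from typing import List


def positions_from_mask(attention_mask: List[List[int]], num_new_tokens: int) -> List[List[int]]:
    """Hypothetical sketch: derive position_ids from a 0/1 attention mask."""
    position_ids = []
    for row in attention_mask:
        running_total = 0
        row_positions = []
        for m in row:
            running_total += m
            # cumsum - 1 for real tokens; padded slots get a dummy position of 1.
            row_positions.append(running_total - 1 if m else 1)
        # Keep one position per uncached input token (not only the last one),
        # so several assisted-decoding candidates get distinct positions.
        position_ids.append(row_positions[-num_new_tokens:])
    return position_ids


print(positions_from_mask([[0, 1, 1, 1, 1]], num_new_tokens=2))  # [[2, 3]]
```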
    dcc49d8a
modeling_gptj.py