    Cleaning up `ConversationalPipeline` to support more than DialoGPT. (#10002) · b1aa4982
    Nicolas Patry authored
    * Cleaning up `ConversationalPipeline` to support more than DialoGPT.
    
    Currently, `ConversationalPipeline` is heavily biased towards DialoGPT,
    which is the default model for this pipeline.
    
    This PR proposes moving the DialoGPT-specific modifications into
    tokenizer-specific behavior wherever possible, by creating a
    `_build_conversation_input_ids` method that takes a conversation as
    input and returns a list of ints corresponding to the tokens. The
    tokenizer feels like the natural place for this: different models
    probably have different strategies for building input_ids from the
    full conversation, and it's the tokenizer's job to transform strings
    into tokens (and vice versa).
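A minimal sketch of what a DialoGPT-style `_build_conversation_input_ids` could look like, assuming a strategy where every turn is encoded and followed by an EOS token. `ToyTokenizer` (a whitespace tokenizer with EOS id 0) and the small `Conversation` dataclass are hypothetical stand-ins written for illustration, not the library's classes; only the method name comes from this PR.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Conversation:
    # Hypothetical stand-in for the pipeline's Conversation object.
    past_user_inputs: List[str] = field(default_factory=list)
    generated_responses: List[str] = field(default_factory=list)
    new_user_input: Optional[str] = None


class ToyTokenizer:
    # Hypothetical whitespace tokenizer; id 0 is reserved for EOS.
    eos_token_id = 0

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}

    def encode(self, text: str) -> List[int]:
        # Assign a fresh id the first time a word is seen.
        return [self.vocab.setdefault(w, len(self.vocab) + 1) for w in text.split()]

    def _build_conversation_input_ids(self, conversation: Conversation) -> List[int]:
        # Collect the turns in chronological order.
        turns: List[str] = []
        for user_input, response in zip(conversation.past_user_inputs,
                                        conversation.generated_responses):
            turns.extend([user_input, response])
        if conversation.new_user_input:
            turns.append(conversation.new_user_input)
        # DialoGPT-style strategy (assumed): EOS after every turn.
        input_ids: List[int] = []
        for text in turns:
            input_ids.extend(self.encode(text))
            input_ids.append(self.eos_token_id)
        return input_ids
```

For example, `ToyTokenizer()._build_conversation_input_ids(Conversation(["hello there"], ["hi"], "how are you"))` gives `[1, 2, 0, 3, 0, 4, 5, 6, 0]`. Another model's tokenizer would override only this method with its own layout.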
    
    If a tokenizer does not define `_build_conversation_input_ids`, the
    previous behavior is used, so nothing breaks for now (except for
    Blenderbot, where the change is a fix).
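The fallback described above can be sketched as a `hasattr` check, under simplifying assumptions: the conversation is passed as a plain list of turn strings, `_legacy_input_ids` is a hypothetical stand-in for the previous DialoGPT-style concatenation, and both toy tokenizers (word length as token id, EOS id 0) are invented for illustration.

```python
from typing import List, Sequence


class ToyTokenizer:
    # Hypothetical tokenizer without the hook; id = word length, EOS = 0.
    eos_token_id = 0

    def encode(self, text: str) -> List[int]:
        return [len(word) for word in text.split()]


class HookedToyTokenizer(ToyTokenizer):
    # Hypothetical Blenderbot-like tokenizer: a single EOS at the very end.
    def _build_conversation_input_ids(self, turns: Sequence[str]) -> List[int]:
        ids: List[int] = []
        for text in turns:
            ids.extend(self.encode(text))
        return ids + [self.eos_token_id]


def _legacy_input_ids(tokenizer: ToyTokenizer, turns: Sequence[str]) -> List[int]:
    # Stand-in for the previous DialoGPT-style behavior: EOS after every turn.
    ids: List[int] = []
    for text in turns:
        ids.extend(tokenizer.encode(text))
        ids.append(tokenizer.eos_token_id)
    return ids


def build_input_ids(tokenizer: ToyTokenizer, turns: Sequence[str]) -> List[int]:
    # Prefer the tokenizer-specific strategy; otherwise keep the old behavior.
    if hasattr(tokenizer, "_build_conversation_input_ids"):
        return tokenizer._build_conversation_input_ids(turns)
    return _legacy_input_ids(tokenizer, turns)
```

With `["hi there", "hello"]`, the plain tokenizer yields `[2, 5, 0, 5, 0]` while the hooked one yields `[2, 5, 5, 0]`, so only tokenizers that opt in change behavior.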
    
    This PR also fixes inputs that are too long. There used to be dead
    code that tried to limit the size of the incoming input. The fix
    introduced here limits the input within `_build_conversation_input_ids`
    to `tokenizer.model_max_length`. This matches the intent of the removed
    dead code and is actually better, because it relies on
    `model_max_length`, which is different from `max_length` (a default
    parameter for `generate`).
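A minimal sketch of the length limit, assuming the oldest tokens are dropped so the most recent turns survive; that truncation side is an assumption here, and `limit_to_model_max_length` is a hypothetical helper name. The bound is the tokenizer's `model_max_length` (what the model can accept as input), not `generate`'s `max_length`.

```python
from typing import List


def limit_to_model_max_length(input_ids: List[int], model_max_length: int) -> List[int]:
    # Hypothetical helper: cap the input at the model's maximum input length.
    if model_max_length <= 0:
        raise ValueError("model_max_length must be positive")
    if len(input_ids) > model_max_length:
        # Assumed strategy: keep the last tokens, i.e. the most recent turns.
        input_ids = input_ids[-model_max_length:]
    return input_ids
```

For example, `limit_to_model_max_length(list(range(10)), 4)` returns `[6, 7, 8, 9]`, while inputs already within the limit are returned unchanged.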
    
    - Removed the `history` logic from `Conversation`, as it's no longer
    relevant now that the tokenization logic has moved to the tokenizer.
    The tokenizer cannot save any cache, and the conversation cannot know
    what is relevant or not. It's also not usable with `blenderbot`,
    because its input_ids are not append-only (the EOS token is always
    at the end).
    
    - Added an `iter_texts` method on `Conversation`, because the code
    was littered with some form of this iteration over past user inputs
    and generated responses.
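A minimal sketch of such an `iter_texts` generator, assuming it yields `(is_user, text)` pairs in chronological order. The dataclass is a hypothetical stand-in, and its field names (`past_user_inputs`, `generated_responses`, `new_user_input`) are assumed to mirror the pipeline's `Conversation`.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Tuple


@dataclass
class Conversation:
    # Hypothetical minimal Conversation with the fields iter_texts walks over.
    past_user_inputs: List[str] = field(default_factory=list)
    generated_responses: List[str] = field(default_factory=list)
    new_user_input: Optional[str] = None

    def iter_texts(self) -> Iterator[Tuple[bool, str]]:
        # Alternate past user inputs and generated responses.
        for user_input, response in zip(self.past_user_inputs, self.generated_responses):
            yield True, user_input
            yield False, response
        # The pending user input, if any, comes last.
        if self.new_user_input:
            yield True, self.new_user_input
```

Callers can then write a single loop such as `for is_user, text in conversation.iter_texts(): ...` instead of repeating the `zip` over both lists.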
    
    * Removing torch mention in types.
    
    * Adding type checking to `_build_conversation_input_ids`.
    
    * Fixing import in strings.