    add flax whisper implementation (#20479) · 2840272c
    Andy Ehrenberg authored
    
    
    * add flax whisper implementation
    
    * revert change to setup
    
    * remove unused imports
    
    * revert generation changes
    
    * flax whisper docs
    
    * docs
    
    * import order
    
    * import sorting
    
    * isort
    
    * add dummy objects
    
    * doc formatting
    
    * formatting
    
    * remove trailing whitespaces
    
    * fix flax whisper docs
    
    * add generation logic to unlock flax whisper
    
    * remove scans
    
    * give credits to Flax Bart implementation
    
    * remove unused imports
    
    * add license
    
    * remove assert
    
    * more credits to Bart
    
    * fix style
    
    * formatting
    
    * support left padding
    
    * add flax whisper generation test
    
    * remove copied from comments whenever not a full copy
    
    * fix docstrings for logits processors
    
    * revert change to FlaxForceTokensLogitsProcessor
    
    * revert doc changes
    
    * improve generation docs
    
    * reorganize
    
    * formatting
    
    * cleanup docs
    
    * add tests
    
    * handle empty list case
    
    * fix forced decoder ids in flax tests
    
    * add flax whisper to inits
    
    * update dummy objects
    
    * docs for FlaxAutoModelForSpeechSeq2Seq
    
    * fix decoder_position_ids computation in pretrained model decode/__call__ fns
    
    * add Copied from statements as necessary
    
    * compute position_ids only in __call__ and decode methods of pretrained model subclasses
    
    * improve readability of compute positional embeddings
    
    * check dimensionality of input_features instead of hidden_states
    
    * copied from statement for init_cache
    
    * formatting
    
    * fix copies
    
    * fix copies
    
    * pass attention mask to encoder layers
    
    * fix decoder module outputs
    
    * set dtype
    Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * smaller flax model for whisper test
    
    * Update src/transformers/generation/flax_utils.py
    Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
    
    * Update src/transformers/models/whisper/modeling_flax_whisper.py
    Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
    
    * Update tests/models/whisper/test_modeling_flax_whisper.py
    Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
    
    * cleanup
    Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
    
    * Update src/transformers/models/whisper/modeling_flax_whisper.py
    Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
    
    * bias cleanup
    
    * doc fix
    
    * align style for force tokens processor
    
    * readability
    
    * fix input shape in tests
    
    * revert FlaxGenerationMixin docstring
    
    * formatting
    
    * fix tests
    
    * fix imports
    
    * consistent encoder hidden states
    
    * consistent hidden states
    
    * input shapes
    
    * typo
    
    * partial class trick
    
    * partial class for input shape
    
    * base_class with correct input shape
    
    * partial base classes
    
    * match by name
    
    * set main_input_name
    
    * compare on names
    
    * formatting
    
    * remove unused import
    
    * safer position ids computation
    
    * safer position id computation
    
    * Update src/transformers/models/whisper/modeling_flax_whisper.py
    Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * Update src/transformers/models/whisper/modeling_flax_whisper.py
    Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * remove identical inherited tests
    
    * fix prompt ids in tests
    
    * use generation config
    
    * use jnp array
    
    * better var names
    
    * more explicit bias use
    
    * import transformers
    
    * formatting
    
    * test formatting
    
    * remove unused imports
    
    * remove unused imports
    
    * formatting
    
    * isort
    
    * docs
    
    * fix ln orders for encoder hidden states
    
    * whisper unique generation stuff
    
    * flake
    
    * use finfo for attention bias
    
    * docs
    
    * Update src/transformers/generation/flax_utils.py
    Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
    
    * docs
    
    * add timestamp flax test
    
    * jit for timestamps
    
    * formatting
    
    * clean up timestamps processor
    
    * formatting
    
    * remove if_true
    
    * cleanup
    
    ---------
    Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
    Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
whisper.mdx 3.16 KB