    Add bloom flax (#25094) · e9310363
    Sanchit Gandhi authored
    
    
    * First commit
    
    * step 1 working
    
    * add alibi
    
    * placeholder for `scan`
    
    * add matrix mult alibi
    
    * beta scaling factor for bmm
    
    * working v1 - simple forward pass
    
    * move layer_number from attribute to arg in call
    
    * partial functioning scan
    
    * hacky working scan
    
    * add more modifs
    
    * add test
    
    * update scan for new kwarg order
    
    * fix position_ids problem
    
    * fix bug in attention layer
    
    * small fix
    
    - do the alibi broadcasting only once
    
    * prelim refactor
    
    * finish refactor
    
    * alibi shifting
    
    * incorporate dropout_add to attention module
    
    * make style
    
    * make padding work again
    
    * update
    
    * remove bogus file
    
    * up
    
    * get generation to work
    
    * clean code a bit
    
    * added small tests
    
    * adding alibi test
    
    * make CI tests pass:
    
    - change init weight
    - add correct tuple for output attention
    - add scan test
    - make CI tests work
    
    * fix few nits
    
    * fix nit onnx
    
    * fix onnx nit
    
    * add missing dtype args to nn.Modules
    
    * remove debugging statements
    
    * fix scan generate
    
    * Update modeling_flax_bloom.py
    
    * Update test_modeling_flax_bloom.py
    
    * Update test_modeling_flax_bloom.py
    
    * Update test_modeling_flax_bloom.py
    
    * fix small test issue + make style
    
    * clean up
    
    * Update tests/models/bloom/test_modeling_flax_bloom.py
    Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * fix function name
    
    * small fix test
    
    * forward contrib credits from PR17761
    
    * Fix failing test
    
    * fix small typo in documentation
    
    * fix non-passing test
    
    - remove device from build alibi
    
    * refactor call
    
    - refactor `FlaxBloomBlockCollection` module
    
    * make style
    
    * upcast to fp32
    
    * cleaner way to upcast
    
    * remove unused args
    
    * remove layer number
    
    * fix scan test
    
    * make style
    
    * fix i4 casting
    
    * fix slow test
    
    * Update src/transformers/models/bloom/modeling_flax_bloom.py
    Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * remove `layer_past`
    
    * refactor a bit
    
    * fix `scan` slow test
    
    * remove useless import
    
    * major changes
    
    - remove unused code
    - refactor a bit
    - revert import `torch`
    
    * major refactoring
    
    - change build alibi
    
    * remove scan
    
    * fix tests
    
    * make style
    
    * clean-up alibi
    
    * add integration tests
    
    * up
    
    * fix batch norm conversion
    
    * style
    
    * style
    
    * update pt-fx cross tests
    
    * update copyright
    
    * Update src/transformers/modeling_flax_pytorch_utils.py
    Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
    
    * per-weight check
    
    * style
    
    * line formats
    
    ---------
    Co-authored-by: younesbelkada <younesbelkada@gmail.com>
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
    Co-authored-by: haileyschoelkopf <haileyschoelkopf@users.noreply.github.com>
    Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>