• [Flax] improve large model init and loading (#16148) · d3bd9ac7
    Suraj Patil authored
    
    
    * begin do_init
    
    * add params_shape_tree
    
    * raise error if params are accessed when do_init is False
    
    * don't allow do_init=False when keys are missing
    
    * make shape tree a property
    
    * assign self._params at the end
    
    * add test for do_init
    
    * add do_init arg to all flax models
    
    * fix param setting
    
    * disable do_init for composite models
    
    * update test
    
    * add do_init in FlaxBigBirdForMultipleChoice
    
    * better names and errors
    
    * improve test
    
    * style
    
    * add a warning when do_init=False
    
    * remove extra if
    
    * set params after _required_params
    
    * add test for from_pretrained
    
    * do_init => _do_init
    
    * change warning to info
    
    * fix typo
    
    * add params in init_weights
    
    * add params to gpt neo init
    
    * add params to init_weights
    
    * update do_init test
    
    * Trigger CI
    
    * Apply suggestions from code review
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    
    * update template
    
    * trigger CI
    
    * style
    
    * style
    
    * fix template
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
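The behaviour described in the commit message (a `_do_init` flag, an error when params are accessed without initialization, a shape tree exposed as a property, and `init_weights` accepting existing params) can be illustrated with a small toy sketch. This is plain Python and not the transformers implementation: `ToyModel`, its `_shapes` layout, and the `seed` argument are hypothetical names chosen for illustration only.

```python
import random


class ToyModel:
    """Toy stand-in for a Flax model wrapper (hypothetical, for illustration)."""

    # Hypothetical layout: parameter name -> number of weights.
    _shapes = {"embed": 4, "dense": 3, "head": 2}

    def __init__(self, seed=0, _do_init=True):
        self._is_initialized = _do_init
        # With _do_init=False no weights are allocated inside the model.
        self._params = self.init_weights(seed) if _do_init else None

    @property
    def params_shape_tree(self):
        # Shapes are available without allocating any weights.
        return dict(self._shapes)

    @property
    def params(self):
        if not self._is_initialized:
            raise ValueError(
                "params are not available when _do_init=False; "
                "call init_weights and keep the result outside the model."
            )
        return self._params

    def init_weights(self, seed, params=None):
        # Randomly initialize every entry, then keep any values supplied in `params`.
        rng = random.Random(seed)
        fresh = {name: [rng.random() for _ in range(n)] for name, n in self._shapes.items()}
        if params is not None:
            fresh.update(params)
        return fresh


model = ToyModel(_do_init=False)
# Pretend the checkpoint lacks the "head" weights.
loaded = {"embed": [0.0] * 4, "dense": [0.0] * 3}
missing = set(model.params_shape_tree) - set(loaded)
# Fill in only the missing entries; loaded values are preserved.
params = model.init_weights(seed=0, params=loaded)
```

In this sketch the caller holds `params` outside the model, and the shape tree is what makes it possible to detect missing keys before any random initialization takes place.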
test_modeling_flax_common.py 46.4 KB