    Flax CLM script (#12023) · 15b498f3
    Suraj Patil authored
    * first draft
    
    * max_seq_length => block_size
    
    * fix arg names
    
    * fix typos
    
    * fix loss calculation
    
    * add max examples, fix train/eval steps, metrics
    
    * optimizer mask
    
    * fix perplexity, metric logging
    
    * fix logging
    
    * data_collator => data_loader
    
    * refactor loss_fn
    
    * support single GPU
    
    * pass distributed to write_metric
    
    * fix jitting
    
    * fix single device training
    
    * fix single device metrics
    
    * close inner progress bars once finished
    
    * add overwrite_cache arg
    
    * fix dataset caching issue
    
    * add more logs
    
    * few small fixes
    
    * address Nicholas's suggestions
    
    * fix docstring
    
    * address Patrick's suggestions
    
    * make flake happy
    
    * pass new_dropout_rng to apply_gradients
    
    * reset train metrics after every epoch
    
    * remove distributed logic, small fixes
run_clm_flax.py 25.2 KB