    reduce memory footprint for average_checkpoints (#647) · d63477e1
    Yongqiang Wang authored
    Summary:
    Pull Request resolved: https://github.com/pytorch/fairseq/pull/647
    
    The current implementation of average_checkpoints requires loading all
    the model parameters into memory and then averaging them. Averaging large
    models (e.g., a transformer) over a large number of checkpoints (e.g., >50)
    may require over 100 GB of memory.
    
    Loading all the parameters at once is not necessary: since the number of models
    is known in advance, each checkpoint's contribution can be accumulated into a
    running average as it is loaded (see the sketch below).
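    A minimal sketch of this running-sum approach, assuming fairseq-style checkpoints
    that store parameters under a "model" key; this is illustrative rather than the
    exact code from the PR, and the function name average_checkpoints_streaming is
    hypothetical:

        import collections
        import torch

        def average_checkpoints_streaming(checkpoint_paths):
            """Average 'model' state dicts incrementally, one checkpoint in memory at a time."""
            num_models = len(checkpoint_paths)  # known in advance
            averaged = collections.OrderedDict()
            for path in checkpoint_paths:
                state = torch.load(path, map_location="cpu")
                for name, param in state["model"].items():
                    # Accumulate in float32 to avoid precision loss with fp16 checkpoints.
                    contribution = param.float() / num_models
                    if name in averaged:
                        averaged[name] += contribution
                    else:
                        averaged[name] = contribution
                del state  # free this checkpoint's tensors before loading the next
            return averaged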
    
    Reviewed By: skritika
    
    Differential Revision: D15027513
    
    fbshipit-source-id: 0afe37c9a031a9ab0f1e78844a37be49ec5f76f1
average_checkpoints.py 4.93 KB