    support registering layer losses to model · c4860c5b
    Matthew Yu authored
    Summary:
    Pull Request resolved: https://github.com/facebookresearch/d2go/pull/430
    
    We add losses in distillation by instantiating them in the distillation algorithm's `__init__` and then running them during the forward pass.
    
    However, this has some issues:
    * the losses are not registered as modules on the model, since we organize them as a list of `LayerLossMetadata` => things like AMP do not behave as expected
    * the losses are not on the same device as the rest of the model, since they may be created after the model has been moved to a new device
    
    This diff solves both issues with a helper function, `register_layer_losses_and_to_device`, which takes a `List[LayerLossMetadata]`, moves each loss to the same device as the model, and registers the losses as submodules of the model.
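    A minimal sketch of what such a helper could look like, assuming `LayerLossMetadata` carries an `nn.Module` loss and a `name` to register it under (field names here are illustrative, not the actual d2go definition):

    ```python
    from dataclasses import dataclass, replace
    from typing import List

    import torch.nn as nn


    @dataclass
    class LayerLossMetadata:
        # Hypothetical fields for illustration; the real d2go dataclass may
        # carry additional metadata (e.g. student/teacher layer names).
        loss: nn.Module
        name: str


    def register_layer_losses_and_to_device(
        layer_losses: List[LayerLossMetadata], model: nn.Module
    ) -> List[LayerLossMetadata]:
        """Move each loss to the model's device and register it on the model."""
        device = next(model.parameters()).device
        updated = []
        for ll in layer_losses:
            # nn.Module.to() moves the loss in place and returns it
            loss_on_device = ll.loss.to(device)
            # add_module registers the loss so model.modules() and
            # model.state_dict() see it like any other submodule
            model.add_module(ll.name, loss_on_device)
            updated.append(replace(ll, loss=loss_on_device))
        return updated
    ```

    Registering the losses via `add_module` makes them visible to module-level machinery, so device moves, AMP, and checkpointing treat them like any other part of the model.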
    
    Differential Revision: D41296932
    
    fbshipit-source-id: ae7ae0847bce1b5cc481d838b9cae69cea424f25
test_modeling_distillation.py