    Fix unused param in QAT training · 8b03f9aa
    Kai Zhang authored
    Summary:
    In the quantization callback, we prepare the model with the FX quantization API and use only the prepared model in training.
    However, when training with DDP, the parameters in the original model still require grad, causing an unused-parameters RuntimeError.
    Previously, the Lightning trainer trained the model with the find_unused_parameters flag enabled, but if a user manually disables it, they hit the runtime error.
    
    In this diff, the parameters in the original model are frozen. We could consider deleting the original model after preparation to save memory, but we would have to make assumptions about the LightningModule structure, for example that `.model` is the original model, so that we could `delattr(pl_module, "model")`.
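    
    A minimal sketch of the approach described above (not the actual diff): a hypothetical Lightning callback that prepares a QAT copy of `.model` with the FX API and then freezes the original model's parameters so DDP does not flag them as unused. The callback name, input shape, and the `prepare_qat_fx` signature (which varies across PyTorch versions) are assumptions.
    
        import torch
        from pytorch_lightning import Callback
        from torch.ao.quantization import get_default_qat_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_qat_fx
    
        class QATCallback(Callback):  # hypothetical name, not the callback in this diff
            def on_fit_start(self, trainer, pl_module):
                # Prepare a QAT copy of the model with the FX graph-mode API
                # (signature shown is for recent PyTorch releases).
                qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
                example_inputs = (torch.randn(1, 3, 224, 224),)  # assumed input shape
                pl_module.prepared_model = prepare_qat_fx(
                    pl_module.model, qconfig_mapping, example_inputs
                )
                # Freeze the original model's parameters so DDP does not report
                # them as unused; only the prepared model is used in training.
                for param in pl_module.model.parameters():
                    param.requires_grad = False
    
    With the original parameters frozen, training no longer depends on `find_unused_parameters=True` being set on the trainer's DDP strategy.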
    
    Reviewed By: wat3rBro
    
    Differential Revision: D31902368
    
    fbshipit-source-id: 56eabb6b2296278529dd2b94d6aa4c9ec9e9ca6b
Changed file: test_runner_lightning_quantization.py