Commit 8b03f9aa authored by Kai Zhang, committed by Facebook GitHub Bot

Fix unused param in QAT training

Summary:
In the quantization callback, we prepare the model with the FX quantization API and only use the prepared model during training.
However, when training with DDP, the parameters of the original model still require grad, which causes an unused-parameters RuntimeError.
Previously, the Lightning trainer trained the model with the find_unused_parameters flag enabled, but if a user manually disables it, they hit the runtime error.

In this diff, the parameters of the original model are frozen. We could consider deleting the original model after preparation to save memory, but that would require making assumptions about the LightningModule structure, for example that `.model` is the original model, so that we could `delattr(pl_module, "model")`.
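
As an aside, a minimal sketch of the prepare-then-freeze pattern described above (not the d2go implementation; the toy module, shapes, backend choice, and the newer `QConfigMapping`/`example_inputs` form of `prepare_qat_fx` are illustrative assumptions, whereas the diff uses the older dict-based API):

```python
# Minimal sketch of the prepare-then-freeze pattern (not the d2go code).
import copy

import torch
import torch.nn as nn
from torch.ao.quantization import QConfigMapping, get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig("fbgemm"))

# Only the prepared copy participates in forward/backward during training.
prepared = prepare_qat_fx(
    copy.deepcopy(model), qconfig_mapping, example_inputs=(torch.randn(8, 4),)
)

# Freeze the original model; otherwise its parameters still require grad and
# DDP reports them as unused parameters when find_unused_parameters=False.
for p in model.parameters():
    p.requires_grad = False

loss = prepared(torch.randn(8, 4)).sum()
loss.backward()
```

With the original parameters frozen, DDP no longer expects gradients for them, which is what allows disabling `find_unused_parameters` in the test below.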

Reviewed By: wat3rBro

Differential Revision: D31902368

fbshipit-source-id: 56eabb6b2296278529dd2b94d6aa4c9ec9e9ca6b
parent 39054767
@@ -449,11 +449,16 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
             return
         with mode(pl_module, training=True) as train:
-            pl_module._prepared = self.prepare(
+            prepared = self.prepare(
                 _deepcopy(train),
                 configs=self.qconfig_dicts,
                 attrs=self.preserved_attrs,
             )
+            # freeze the original model since only the prepared model will
+            # participate in forward.
+            for x in train.parameters():
+                x.requires_grad = False
+            pl_module._prepared = prepared
             pl_module.forward = MethodType(_quantized_forward, pl_module)
             self.prepared = pl_module._prepared
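
For context on the `MethodType` rebinding above, a toy illustration of routing `forward` through the prepared copy (the real `_quantized_forward` lives in d2go; `ToyModule` and its attributes are made up here):

```python
from types import MethodType

import torch
import torch.nn as nn


def _quantized_forward_sketch(self, *args, **kwargs):
    # Delegate to the prepared copy stored on the module; the frozen original
    # under `self.model` is bypassed entirely during training.
    return self._prepared(*args, **kwargs)


class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)       # original (frozen) model
        self._prepared = nn.Linear(4, 2)   # stand-in for the FX-prepared copy

    def forward(self, x):
        return self.model(x)


m = ToyModule()
m.forward = MethodType(_quantized_forward_sketch, m)
out = m(torch.randn(3, 4))  # now runs through m._prepared
```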
@@ -226,7 +226,11 @@ class DefaultTask(pl.LightningModule):
     def configure_optimizers(
         self,
     ) -> Tuple[List[torch.optim.Optimizer], List]:
-        optim = build_optimizer_mapper(self.cfg, self.model)
+        model = self.model
+        if hasattr(self, PREPARED):
+            # train the prepared model for FX quantization
+            model = getattr(self, PREPARED)
+        optim = build_optimizer_mapper(self.cfg, model)
         lr_scheduler = d2_build_lr_scheduler(self.cfg, optim)
         return [optim], [{"scheduler": lr_scheduler, "interval": "step"}]
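
A small sanity check of why `configure_optimizers` must hand the prepared model to the optimizer builder once the callback freezes the original (the names below are illustrative, not the d2go helpers):

```python
import copy

import torch
import torch.nn as nn

original = nn.Linear(4, 2)
prepared = copy.deepcopy(original)  # stand-in for the FX-prepared copy

# What the callback now does to the original model.
for p in original.parameters():
    p.requires_grad = False

# The frozen original has no trainable parameters left, so an optimizer built
# over it would never update anything; the prepared copy must be optimized.
assert all(not p.requires_grad for p in original.parameters())
assert all(p.requires_grad for p in prepared.parameters())
optim = torch.optim.SGD(prepared.parameters(), lr=0.01)
```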
@@ -20,6 +20,7 @@ from d2go.utils.testing.helper import tempdir
 from d2go.utils.testing.lightning_test_module import TestModule
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.plugins import DDPPlugin
 from torch.ao.quantization import (  # @manual
     default_dynamic_qconfig,
     get_default_qconfig,
@@ -189,11 +190,14 @@ class TestQuantizationAwareTraining(unittest.TestCase):
         num_epochs = 2
         qat = QuantizationAwareTraining()
         trainer = Trainer(
+            accelerator="ddp_cpu",
+            num_processes=1,
             default_root_dir=os.path.join(root_dir, "quantized"),
             checkpoint_callback=False,
             callbacks=[qat],
             max_epochs=num_epochs,
             logger=False,
+            plugins=[DDPPlugin(find_unused_parameters=False)],
         )
         trainer.fit(model)