Commit 9e93852d authored by Yanghan Wang, committed by Facebook GitHub Bot

upgrade pytorch-lightning version to 1.8.6

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/453

Previous diffs updated the LRScheduler to the public version (e.g. https://github.com/facebookresearch/detectron2/pull/4709), which also requires a newer version of pytorch-lightning. This diff upgrades the lightning version to 1.8.6 and fixes call sites that relied on APIs from older lightning versions (an illustrative sketch of the new hook signatures follows this list):
- `deepcopy` of a LightningModule is now supported directly, so `_deepcopy` is removed (accessing the `trainer` attribute while it is `None` is no longer allowed).
- `dataloader_idx` has been removed from the `on_train_batch_start` signature.
- Stop using `_accelerator_connector` (the AcceleratorConnector no longer exposes those attributes); check `trainer.strategy` instead.
- Replace the deprecated `on_pretrain_routine_end` hook with `on_fit_start`.
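
For reference, a minimal sketch (illustrative only, not part of this diff; the class names are hypothetical) of the pytorch-lightning 1.8.x hook signatures these call sites are migrated to:

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy


class MyCallback(pl.Callback):  # hypothetical callback
    # `dataloader_idx` is no longer part of this hook's signature in 1.8.x.
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None:
        ...


class MyTask(pl.LightningModule):  # hypothetical LightningModule
    # Replaces the removed `on_pretrain_routine_end` hook.
    def on_fit_start(self) -> None:
        # Checking `trainer.strategy` replaces the private
        # `_accelerator_connector` attributes.
        assert isinstance(self.trainer.strategy, (SingleDeviceStrategy, DDPStrategy))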

Reviewed By: YanjunChen329

Differential Revision: D42319019

fbshipit-source-id: ba46abbd98da96783e15d187a361fda47dc7d4d6
parent 2246aba3
@@ -54,18 +54,6 @@ def rhasattr(obj: Any, attr: str, *args) -> bool:
     return True
 
 
-def _deepcopy(pl_module: LightningModule) -> LightningModule:
-    """Copy a LightningModule. Some properties need to be ignored."""
-    # Remove trainer reference.
-    trainer = pl_module.trainer
-    try:
-        pl_module.trainer = None
-        copy = deepcopy(pl_module)
-    finally:
-        pl_module.trainer = trainer
-    return copy
-
-
 def _quantized_forward(self, *args, **kwargs):
     """Forward method for a quantized module."""
     if not self.training and hasattr(self, "_quantized"):
@@ -100,7 +88,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
 def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
     if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
         # model has been prepared for QAT before saving into checkpoint
-        copied = _deepcopy(model)
+        copied = deepcopy(model)
         prepared = prepare_fake_quant_model(copied.cfg, copied.model, is_qat=True)
         copied.model = prepared
         setattr(model, PREPARED, copied)
@@ -465,7 +453,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         with mode(pl_module, training=True) as train:
             prepared = self.prepare(
-                _deepcopy(train),
+                deepcopy(train),
                 configs=self.qconfig_dicts,
                 attrs=self.preserved_attrs,
             )
@@ -483,7 +471,6 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         pl_module: LightningModule,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
     ) -> None:
         """Applies model transforms at as specified during training."""
         apply_only_once = []
@@ -603,7 +590,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
         """
         # Pass a copy to quantization APIs.
         self.prepared = self.prepare(
-            _deepcopy(pl_module).eval(),
+            deepcopy(pl_module).eval(),
            configs=self.qconfig_dicts,
             attrs=self.preserved_attrs,
         )
...
@@ -25,6 +25,7 @@ from d2go.runner.default_runner import (
 from d2go.utils.ema_state import EMAState
 from d2go.utils.misc import get_tensorboard_log_dir
 from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler
+from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.logger import _flatten_dict
@@ -274,10 +275,10 @@ class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule):
     def _reset_dataset_evaluators(self):
         """reset validation dataset evaluator to be run in EVAL_PERIOD steps"""
-        assert (
-            len(self.trainer._accelerator_connector.parallel_devices) == 1
-            or self.trainer._accelerator_connector.use_ddp
-        ), "Only DDP and DDP_CPU distributed backend are supported"
+        assert isinstance(self.trainer.strategy, (SingleDeviceStrategy, DDPStrategy)), (
+            "Only Single Device or DDP strategies are supported,"
+            f" instead found: {self.trainer.strategy}"
+        )
 
     def _get_inference_dir_name(
         base_dir, inference_type, dataset_name, model_tag: ModelTag
@@ -391,7 +392,7 @@ class DefaultTask(D2GoDataAPIMixIn, pl.LightningModule):
     # ---------------------------------------------------------------------------
     # Hooks
     # ---------------------------------------------------------------------------
-    def on_pretrain_routine_end(self) -> None:
+    def on_fit_start(self) -> None:
         if self.cfg.MODEL_EMA.ENABLED:
             if self.ema_state and self.ema_state.has_inited():
                 # ema_state could have been loaded from checkpoint
...
@@ -27,7 +27,7 @@ requirements = [
     "Pillow",
     "mock",
     "torch",
-    "pytorch-lightning @ git+https://github.com/PyTorchLightning/pytorch-lightning@9b011606f",
+    "pytorch-lightning==1.8.6",
     "opencv-python",
     "parameterized",
     # Downgrade the protobuf package to 3.20.x or lower, related:
...
@@ -131,9 +131,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
                 f"step={step}",
             )
             trainer.fit_loop.global_step = step
-            qat.on_train_batch_start(
-                trainer, module, batch=None, batch_idx=0, dataloader_idx=0
-            )
+            qat.on_train_batch_start(trainer, module, batch=None, batch_idx=0)
             self.assertEqual(
                 len(
...