Commit 0848c589 authored by Luis Perez's avatar Luis Perez Committed by Facebook GitHub Bot
Browse files

Synchronize PyTorchLightning/pytorch-lightning (revision 7b283e3c@master) to...

Synchronize PyTorchLightning/pytorch-lightning (revision 7b283e3c@master) to github/third-party/PyTorchLightning/pytorch-lightning

Summary:
# Manual
 - remove fixme's in `model_checkpoint.py`, `parameter_monitor.py`, `test_quantization.py`, and `speed_monitor.py` now that `Trainer` is properly annotated.
- update `test_quantization.py` to `trainer.train_loop.global_step` instead of `trainer.global_step` which is a read-only.
- update `loop_callback.py` to read from `train_loop` for `batch_idx` (which is no longer available).

# Automatic
### New commit log messages
  7b283e3c Bugfix/Multiple dataloaders (#7433)
  d7c44cc6 Docs: sync chlog 1.3.1 (#7478)
  fdf50a5e Mark certain Trainer APIs as protected (#7420)
  ad9118f0 remove trainer hidden state | sanity refactor [1 / n] (#7437)
  4a1134db Log epoch metrics before firing the `on_evaluation_end` hook (#7272)
  b65ae794 Automatically check `DataModule.has_{setup,teardown,prepare_data}` [2/2] (#7238)
  8660d8cf [pre-commit.ci] pre-commit autoupdate (#7475)
  f6fe715e Fix Sphinx argument deprecation (#7464)

Reviewed By: shuyingsunshine21

Differential Revision: D28353491

fbshipit-source-id: 98b87d99e2f09b47b07270858fcbdb5d5299730b
parent 5aa5e3c8
...@@ -20,20 +20,18 @@ from d2go.utils.testing.helper import tempdir ...@@ -20,20 +20,18 @@ from d2go.utils.testing.helper import tempdir
from d2go.utils.testing.lightning_test_module import TestModule from d2go.utils.testing.lightning_test_module import TestModule
from pytorch_lightning import Trainer, seed_everything from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
from torch.quantization import ( # @manual; @manual from torch.quantization import ( # @manual; @manual
default_dynamic_qconfig, default_dynamic_qconfig,
get_default_qconfig, get_default_qconfig,
) )
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
class TestUtilities(unittest.TestCase): class TestUtilities(unittest.TestCase):
""" Test some basic utilities we rely on. """ """Test some basic utilities we rely on."""
def test_get_set_has(self): def test_get_set_has(self):
""" Trivial test for generic behavior. Only support pre-existing deeply nested values.""" """Trivial test for generic behavior. Only support pre-existing deeply nested values."""
class TestObject(object): class TestObject(object):
def __init__(self): def __init__(self):
...@@ -54,10 +52,10 @@ class TestUtilities(unittest.TestCase): ...@@ -54,10 +52,10 @@ class TestUtilities(unittest.TestCase):
class TestModelTransform(unittest.TestCase): class TestModelTransform(unittest.TestCase):
""" Tests ModelTransforms. """ """Tests ModelTransforms."""
def test_invalid_construction_type_error(self): def test_invalid_construction_type_error(self):
""" Validate construction of ModelTransforms. Always have fn, msg, and one of [step, interval]. """ """Validate construction of ModelTransforms. Always have fn, msg, and one of [step, interval]."""
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
_ = ModelTransform() _ = ModelTransform()
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
...@@ -73,7 +71,7 @@ class TestModelTransform(unittest.TestCase): ...@@ -73,7 +71,7 @@ class TestModelTransform(unittest.TestCase):
) )
def test_positivity_value_error(self): def test_positivity_value_error(self):
""" Validates ModelTransforms are constructed with only valid arguments. """ """Validates ModelTransforms are constructed with only valid arguments."""
def identity(x): def identity(x):
return x return x
...@@ -88,7 +86,7 @@ class TestModelTransform(unittest.TestCase): ...@@ -88,7 +86,7 @@ class TestModelTransform(unittest.TestCase):
class TestQuantizationAwareTraining(unittest.TestCase): class TestQuantizationAwareTraining(unittest.TestCase):
def test_qat_misconfiguration(self): def test_qat_misconfiguration(self):
""" Tests failure when misconfiguring the QAT Callback. """ """Tests failure when misconfiguring the QAT Callback."""
invalid_params = [ invalid_params = [
{"start_step": -1}, {"start_step": -1},
{"enable_observer": (42, 42)}, {"enable_observer": (42, 42)},
...@@ -101,7 +99,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -101,7 +99,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
_ = QuantizationAwareTraining(**invalid_param) _ = QuantizationAwareTraining(**invalid_param)
def test_qat_transforms(self): def test_qat_transforms(self):
""" Tests the appropropriate ModelTransforms are defined with QAT.""" """Tests the appropropriate ModelTransforms are defined with QAT."""
qat = QuantizationAwareTraining( qat = QuantizationAwareTraining(
start_step=300, enable_observer=(350, 500), freeze_bn_step=550 start_step=300, enable_observer=(350, 500), freeze_bn_step=550
) )
...@@ -129,7 +127,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -129,7 +127,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
0, 0,
f"step={step}", f"step={step}",
) )
trainer.global_step = step trainer.train_loop.global_step = step
qat.on_train_batch_start( qat.on_train_batch_start(
trainer, module, batch=None, batch_idx=0, dataloader_idx=0 trainer, module, batch=None, batch_idx=0, dataloader_idx=0
) )
...@@ -153,7 +151,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -153,7 +151,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir @tempdir
def test_qat_interval_transform(self, root_dir): def test_qat_interval_transform(self, root_dir):
""" Tests an interval transform is applied multiple times. """ """Tests an interval transform is applied multiple times."""
seed_everything(100) seed_everything(100)
def linear_fn_counter(mod): def linear_fn_counter(mod):
...@@ -182,7 +180,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -182,7 +180,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir @tempdir
def test_module_quantized_during_train(self, root_dir): def test_module_quantized_during_train(self, root_dir):
""" Validate quantized aware training works as expected. """ """Validate quantized aware training works as expected."""
seed_everything(100) seed_everything(100)
model = TestModule() model = TestModule()
...@@ -219,7 +217,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -219,7 +217,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir @tempdir
def test_quantization_without_train(self, root_dir): def test_quantization_without_train(self, root_dir):
""" Validate quantization occurs even without a call to .fit() first. """ """Validate quantization occurs even without a call to .fit() first."""
seed_everything(100) seed_everything(100)
model = TestModule() model = TestModule()
...@@ -240,7 +238,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -240,7 +238,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir @tempdir
def test_attribute_preservation_qat(self, root_dir): def test_attribute_preservation_qat(self, root_dir):
""" Validates we can preserve specified properties in module. """ """Validates we can preserve specified properties in module."""
seed_everything(100) seed_everything(100)
model = TestModule() model = TestModule()
...@@ -274,7 +272,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -274,7 +272,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir @tempdir
def test_quantization_and_checkpointing(self, root_dir): def test_quantization_and_checkpointing(self, root_dir):
""" Validate written checkpoints can be loaded back as expected. """ """Validate written checkpoints can be loaded back as expected."""
seed_everything(100) seed_everything(100)
model = TestModule() model = TestModule()
...@@ -297,10 +295,10 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -297,10 +295,10 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir @tempdir
def test_custom_qat(self, root_dir): def test_custom_qat(self, root_dir):
""" Tests that we can customize QAT by skipping certain layers. """ """Tests that we can customize QAT by skipping certain layers."""
class _CustomQAT(QuantizationAwareTraining): class _CustomQAT(QuantizationAwareTraining):
""" Only quantize TestModule.another_layer. """ """Only quantize TestModule.another_layer."""
def prepare(self, model, configs, attrs): def prepare(self, model, configs, attrs):
model.another_layer = prepare_qat_fx(model.another_layer, configs[""]) model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
...@@ -345,7 +343,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -345,7 +343,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir @tempdir
def test_submodule_qat(self, root_dir): def test_submodule_qat(self, root_dir):
""" Tests that we can customize QAT through exposed API. """ """Tests that we can customize QAT through exposed API."""
seed_everything(100) seed_everything(100)
model = TestModule() model = TestModule()
...@@ -386,7 +384,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -386,7 +384,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
class TestPostTrainingQuantization(unittest.TestCase): class TestPostTrainingQuantization(unittest.TestCase):
@tempdir @tempdir
def test_post_training_static_quantization(self, root_dir): def test_post_training_static_quantization(self, root_dir):
""" Validate post-training static quantization. """ """Validate post-training static quantization."""
seed_everything(100) seed_everything(100)
model = TestModule() model = TestModule()
...@@ -423,7 +421,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -423,7 +421,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
@tempdir @tempdir
def test_post_training_dynamic_quantization(self, root_dir): def test_post_training_dynamic_quantization(self, root_dir):
""" Validates post-training dynamic quantization. """ """Validates post-training dynamic quantization."""
seed_everything(100) seed_everything(100)
model = TestModule() model = TestModule()
...@@ -460,10 +458,10 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -460,10 +458,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
@tempdir @tempdir
def test_custom_post_training_static_quant(self, root_dir): def test_custom_post_training_static_quant(self, root_dir):
""" Tests that we can customize Post-Training static by skipping certain layers. """ """Tests that we can customize Post-Training static by skipping certain layers."""
class _CustomStaticQuant(PostTrainingQuantization): class _CustomStaticQuant(PostTrainingQuantization):
""" Only quantize TestModule.another_layer. """ """Only quantize TestModule.another_layer."""
def prepare(self, model, configs, attrs): def prepare(self, model, configs, attrs):
model.another_layer = prepare_fx(model.another_layer, configs[""]) model.another_layer = prepare_fx(model.another_layer, configs[""])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment