Commit 0848c589 authored by Luis Perez, committed by Facebook GitHub Bot

Synchronize PyTorchLightning/pytorch-lightning (revision 7b283e3c@master) to github/third-party/PyTorchLightning/pytorch-lightning

Summary:
# Manual
- remove FIXMEs in `model_checkpoint.py`, `parameter_monitor.py`, `test_quantization.py`, and `speed_monitor.py` now that `Trainer` is properly annotated.
- update `test_quantization.py` to use `trainer.train_loop.global_step` instead of `trainer.global_step`, which is now read-only.
- update `loop_callback.py` to read `batch_idx` from `train_loop`, since it is no longer available directly on the `Trainer` (a short sketch of both access patterns follows this list).
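
A minimal sketch of the access pattern behind the two bullets above; `StepLogger` and `_force_global_step` are hypothetical names, and the `train_loop` attribute names are assumed to match the Lightning revision being synced here:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class StepLogger(Callback):
    """Hypothetical callback: batch_idx is read from the train loop, not the Trainer."""

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        print(
            f"step={trainer.train_loop.global_step} "
            f"batch_idx={trainer.train_loop.batch_idx}"
        )


def _force_global_step(trainer: Trainer, step: int) -> None:
    # Tests that previously assigned trainer.global_step directly now go
    # through the train loop, since the Trainer attribute is read-only
    # in this revision.
    trainer.train_loop.global_step = step
```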

# Automatic
### New commit log messages
  7b283e3c Bugfix/Multiple dataloaders (#7433)
  d7c44cc6 Docs: sync chlog 1.3.1 (#7478)
  fdf50a5e Mark certain Trainer APIs as protected (#7420)
  ad9118f0 remove trainer hidden state | sanity refactor [1 / n] (#7437)
  4a1134db Log epoch metrics before firing the `on_evaluation_end` hook (#7272)
  b65ae794 Automatically check `DataModule.has_{setup,teardown,prepare_data}` [2/2] (#7238)
  8660d8cf [pre-commit.ci] pre-commit autoupdate (#7475)
  f6fe715e Fix Sphinx argument deprecation (#7464)

Reviewed By: shuyingsunshine21

Differential Revision: D28353491

fbshipit-source-id: 98b87d99e2f09b47b07270858fcbdb5d5299730b
parent 5aa5e3c8
@@ -20,20 +20,18 @@ from d2go.utils.testing.helper import tempdir
from d2go.utils.testing.lightning_test_module import TestModule
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
from torch.quantization import (  # @manual
default_dynamic_qconfig,
get_default_qconfig,
)
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
class TestUtilities(unittest.TestCase):
""" Test some basic utilities we rely on. """
"""Test some basic utilities we rely on."""
def test_get_set_has(self):
""" Trivial test for generic behavior. Only support pre-existing deeply nested values."""
"""Trivial test for generic behavior. Only support pre-existing deeply nested values."""
class TestObject(object):
def __init__(self):
@@ -54,10 +52,10 @@ class TestUtilities(unittest.TestCase):
class TestModelTransform(unittest.TestCase):
""" Tests ModelTransforms. """
"""Tests ModelTransforms."""
def test_invalid_construction_type_error(self):
""" Validate construction of ModelTransforms. Always have fn, msg, and one of [step, interval]. """
"""Validate construction of ModelTransforms. Always have fn, msg, and one of [step, interval]."""
with self.assertRaises(TypeError):
_ = ModelTransform()
with self.assertRaises(TypeError):
@@ -73,7 +71,7 @@ class TestModelTransform(unittest.TestCase):
)
def test_positivity_value_error(self):
""" Validates ModelTransforms are constructed with only valid arguments. """
"""Validates ModelTransforms are constructed with only valid arguments."""
def identity(x):
return x
@@ -88,7 +86,7 @@ class TestModelTransform(unittest.TestCase):
class TestQuantizationAwareTraining(unittest.TestCase):
def test_qat_misconfiguration(self):
""" Tests failure when misconfiguring the QAT Callback. """
"""Tests failure when misconfiguring the QAT Callback."""
invalid_params = [
{"start_step": -1},
{"enable_observer": (42, 42)},
@@ -101,7 +99,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
_ = QuantizationAwareTraining(**invalid_param)
def test_qat_transforms(self):
""" Tests the appropropriate ModelTransforms are defined with QAT."""
"""Tests the appropropriate ModelTransforms are defined with QAT."""
qat = QuantizationAwareTraining(
start_step=300, enable_observer=(350, 500), freeze_bn_step=550
)
@@ -129,7 +127,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
0,
f"step={step}",
)
trainer.global_step = step
trainer.train_loop.global_step = step
qat.on_train_batch_start(
trainer, module, batch=None, batch_idx=0, dataloader_idx=0
)
@@ -153,7 +151,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir
def test_qat_interval_transform(self, root_dir):
""" Tests an interval transform is applied multiple times. """
"""Tests an interval transform is applied multiple times."""
seed_everything(100)
def linear_fn_counter(mod):
@@ -182,7 +180,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir
def test_module_quantized_during_train(self, root_dir):
""" Validate quantized aware training works as expected. """
"""Validate quantized aware training works as expected."""
seed_everything(100)
model = TestModule()
@@ -219,7 +217,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir
def test_quantization_without_train(self, root_dir):
""" Validate quantization occurs even without a call to .fit() first. """
"""Validate quantization occurs even without a call to .fit() first."""
seed_everything(100)
model = TestModule()
@@ -240,7 +238,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir
def test_attribute_preservation_qat(self, root_dir):
""" Validates we can preserve specified properties in module. """
"""Validates we can preserve specified properties in module."""
seed_everything(100)
model = TestModule()
@@ -274,7 +272,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir
def test_quantization_and_checkpointing(self, root_dir):
""" Validate written checkpoints can be loaded back as expected. """
"""Validate written checkpoints can be loaded back as expected."""
seed_everything(100)
model = TestModule()
@@ -297,10 +295,10 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir
def test_custom_qat(self, root_dir):
""" Tests that we can customize QAT by skipping certain layers. """
"""Tests that we can customize QAT by skipping certain layers."""
class _CustomQAT(QuantizationAwareTraining):
""" Only quantize TestModule.another_layer. """
"""Only quantize TestModule.another_layer."""
def prepare(self, model, configs, attrs):
model.another_layer = prepare_qat_fx(model.another_layer, configs[""])
@@ -345,7 +343,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
@tempdir
def test_submodule_qat(self, root_dir):
""" Tests that we can customize QAT through exposed API. """
"""Tests that we can customize QAT through exposed API."""
seed_everything(100)
model = TestModule()
@@ -386,7 +384,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
class TestPostTrainingQuantization(unittest.TestCase):
@tempdir
def test_post_training_static_quantization(self, root_dir):
""" Validate post-training static quantization. """
"""Validate post-training static quantization."""
seed_everything(100)
model = TestModule()
@@ -423,7 +421,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
@tempdir
def test_post_training_dynamic_quantization(self, root_dir):
""" Validates post-training dynamic quantization. """
"""Validates post-training dynamic quantization."""
seed_everything(100)
model = TestModule()
@@ -460,10 +458,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
@tempdir
def test_custom_post_training_static_quant(self, root_dir):
""" Tests that we can customize Post-Training static by skipping certain layers. """
"""Tests that we can customize Post-Training static by skipping certain layers."""
class _CustomStaticQuant(PostTrainingQuantization):
""" Only quantize TestModule.another_layer. """
"""Only quantize TestModule.another_layer."""
def prepare(self, model, configs, attrs):
model.another_layer = prepare_fx(model.another_layer, configs[""])
......