Commit f781223c authored by Ananth Subramaniam's avatar Ananth Subramaniam Committed by Facebook GitHub Bot
Browse files

Trainer(checkpoint_callback) -> Trainer(enable_checkpointing)

Reviewed By: kazhang

Differential Revision: D34669519

fbshipit-source-id: 8cfee968104c823a55960f2730d8e888ac1f298e
parent 82c6a50b
...@@ -168,7 +168,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -168,7 +168,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
) )
trainer = Trainer( trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[qat], callbacks=[qat],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
...@@ -192,7 +192,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -192,7 +192,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
accelerator="cpu", accelerator="cpu",
devices=1, devices=1,
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[qat], callbacks=[qat],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
...@@ -227,7 +227,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -227,7 +227,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
qat = QuantizationAwareTraining() qat = QuantizationAwareTraining()
trainer = Trainer( trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[qat], callbacks=[qat],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
...@@ -254,7 +254,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -254,7 +254,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
) )
trainer = Trainer( trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[qat], callbacks=[qat],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
...@@ -318,7 +318,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -318,7 +318,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
qat = _CustomQAT() qat = _CustomQAT()
trainer = Trainer( trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[qat], callbacks=[qat],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
...@@ -357,7 +357,7 @@ class TestQuantizationAwareTraining(unittest.TestCase): ...@@ -357,7 +357,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
) )
trainer = Trainer( trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[qat], callbacks=[qat],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
...@@ -396,7 +396,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -396,7 +396,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
) )
trainer = Trainer( trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[static_quantization], callbacks=[static_quantization],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
...@@ -433,7 +433,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -433,7 +433,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
) )
trainer = Trainer( trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[dynamic_quant], callbacks=[dynamic_quant],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
...@@ -480,7 +480,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -480,7 +480,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
static_quantization = _CustomStaticQuant() static_quantization = _CustomStaticQuant()
trainer = Trainer( trainer = Trainer(
default_root_dir=os.path.join(root_dir, "quantized"), default_root_dir=os.path.join(root_dir, "quantized"),
checkpoint_callback=False, enable_checkpointing=False,
callbacks=[static_quantization], callbacks=[static_quantization],
max_epochs=num_epochs, max_epochs=num_epochs,
logger=False, logger=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment