Unverified Commit f7ea959b authored by Younes Belkada, committed by GitHub

[`core` / `GC` / `tests`] Stronger GC tests (#27124)



* stronger GC tests

* better tests and skip failing tests

* break down into 3 sub-tests

* break down into 3 sub-tests

* refactor a bit

* more refactor

* fix

* last nit

* credits contrib and suggestions

* credits contrib and suggestions

---------
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 5bbf6712
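
The model-specific diffs below all exercise the same pattern: the common test suite now enables checkpointing via `gradient_checkpointing_enable(gradient_checkpointing_kwargs=...)` and asserts that every trainable parameter actually receives a gradient. A minimal sketch of the three scenarios, assuming a `transformers` release that accepts `gradient_checkpointing_kwargs` (4.35 or later); the `"gpt2"` checkpoint and the dummy input ids are illustrative only and are not part of this PR:

```python
import torch
from transformers import AutoModelForCausalLM

# The three configurations covered by the new sub-tests.
gc_configs = [
    None,                      # scenario 1: default behaviour
    {"use_reentrant": True},   # scenario 2: torch.utils.checkpoint's historical default
    {"use_reentrant": False},  # scenario 3: the value PyTorch recommends going forward
]

for gradient_checkpointing_kwargs in gc_configs:
    # "gpt2" is only an illustrative checkpoint, not one used by this PR.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
    model.train()

    input_ids = torch.tensor([[464, 2159, 318, 257, 1049, 1295]])
    loss = model(input_ids=input_ids, labels=input_ids).loss
    loss.backward()

    # The strengthened assertion: every trainable parameter must have received a gradient.
    for name, param in model.named_parameters():
        if param.requires_grad:
            assert param.grad is not None, f"{name} has no gradient!"
```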
@@ -246,6 +246,18 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     def test_generate_fp16(self):
         pass
@@ -702,6 +702,18 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
@@ -987,6 +999,18 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
@@ -1421,6 +1445,18 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
@@ -207,6 +207,18 @@ class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     def test_model_common_attributes(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -369,6 +369,24 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unit
             [self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
         )

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     @parameterized.expand(
         [
             (1, 5, [1]),
@@ -537,6 +537,24 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_disk_offload(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
 @require_torch
 @require_sentencepiece
@@ -320,6 +320,18 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
             loss = model(**inputs).loss
             loss.backward()

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     @unittest.skip(
         reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic
         hidden states"""
@@ -555,6 +555,24 @@ class VisualBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
             model = VisualBertModel.from_pretrained(model_name)
             self.assertIsNotNone(model)

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
 @require_torch
 class VisualBertModelIntegrationTest(unittest.TestCase):
@@ -173,6 +173,18 @@ class VitMatteModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     @unittest.skip(reason="ViTMatte does not support input and output embeddings")
     def test_model_common_attributes(self):
         pass
@@ -414,6 +414,18 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     def test_generate_with_head_masking(self):
         pass
@@ -194,6 +194,18 @@ class XCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     @unittest.skip(reason="XCLIPVisionModel has no base class and is not available in MODEL_MAPPING")
     def test_save_load_fast_init_from_base(self):
         pass
@@ -416,6 +428,18 @@ class XCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
     @unittest.skip(reason="X-CLIP does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass
@@ -539,56 +539,78 @@ class ModelTesterMixin:
         expected_arg_names = ["input_ids"]
         self.assertListEqual(arg_names[:1], expected_arg_names)

-    def test_training(self):
+    def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
         if not self.model_tester.is_training:
             return

         for model_class in self.all_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config.use_cache = False
             config.return_dict = True

-            if model_class.__name__ in [
-                *get_values(MODEL_MAPPING_NAMES),
-                *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
-            ]:
+            if (
+                model_class.__name__
+                in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)]
+                or not model_class.supports_gradient_checkpointing
+            ):
                 continue

+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
             model.to(torch_device)
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
             model.train()
+
+            # unfreeze additional layers
+            for p in model.parameters():
+                p.requires_grad_(True)
+
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
             inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
             loss = model(**inputs).loss
             loss.backward()
+            optimizer.step()

-    def test_training_gradient_checkpointing(self):
+            for k, v in model.named_parameters():
+                if v.requires_grad:
+                    self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
+
+    def test_training(self):
         if not self.model_tester.is_training:
             return

         for model_class in self.all_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            config.use_cache = False
             config.return_dict = True

-            if (
-                model_class.__name__
-                in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)]
-                or not model_class.supports_gradient_checkpointing
-            ):
+            if model_class.__name__ in [
+                *get_values(MODEL_MAPPING_NAMES),
+                *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
+            ]:
                 continue

             model = model_class(config)
             model.to(torch_device)
-            model.gradient_checkpointing_enable()
             model.train()
             inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
             loss = model(**inputs).loss
             loss.backward()

-            model.gradient_checkpointing_disable()
-            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-            model.train()
-            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            loss = model(**inputs).loss
-            loss.backward()
+    def test_training_gradient_checkpointing(self):
+        # Scenario - 1 default behaviour
+        self.check_training_gradient_checkpointing()
+
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        # Scenario - 2 with `use_reentrant=True` - this is the default value that is used in pytorch's
+        # torch.utils.checkpoint.checkpoint
+        self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": True})
+
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        # Scenario - 3 with `use_reentrant=False` pytorch suggests users to use this value for
+        # future releases: https://pytorch.org/docs/stable/checkpoint.html
+        self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": False})

     def test_attention_outputs(self):
         if not self.has_attentions:
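
For context on the `use_reentrant` comments in the hunk above, here is a minimal plain-PyTorch sketch (not from this PR; it assumes a torch version whose `torch.utils.checkpoint.checkpoint` accepts `use_reentrant`) of why reentrant checkpointing can silently drop parameter gradients when the checkpointed inputs do not require grad, which is the failure mode the strengthened tests are meant to surface:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)  # activation that does not require grad (e.g. coming from a frozen input)

# use_reentrant=False: the block is recomputed during backward and gradients still
# reach the layer's parameters, even though the input itself carries no grad.
out = checkpoint(layer, x, use_reentrant=False)
out.sum().backward()
print(layer.weight.grad is not None)  # True

layer.zero_grad()

# use_reentrant=True (the historical default): the checkpointed output only gets a
# grad_fn if at least one input tensor requires grad, so with a detached input the
# parameters inside the block receive no gradients at all.
out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)  # False, so backward cannot flow into the block
```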