Unverified Commit 27d46397 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Make gradient_checkpointing a training argument (#13657)



* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>
Co-authored-by: default avatarStas Bekman <stas@stason.org>
parent 75f6641e
......@@ -57,8 +57,6 @@ class ViTConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
The size (resolution) of each image.
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
......
......@@ -352,6 +352,7 @@ class ViTEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -370,7 +371,7 @@ class ViTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -411,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel):
config_class = ViTConfig
base_model_prefix = "vit"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -428,6 +430,10 @@ class ViTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ViTEncoder):
module.gradient_checkpointing = value
VIT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
......
......@@ -138,8 +138,6 @@ class Wav2Vec2Config(PretrainedConfig):
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
......@@ -198,7 +196,6 @@ class Wav2Vec2Config(PretrainedConfig):
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
......@@ -229,7 +226,6 @@ class Wav2Vec2Config(PretrainedConfig):
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.gradient_checkpointing = gradient_checkpointing
self.use_weighted_layer_sum = use_weighted_layer_sum
self.classifier_proj_size = classifier_proj_size
......
......@@ -590,6 +590,7 @@ class Wav2Vec2Encoder(nn.Module):
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -629,7 +630,7 @@ class Wav2Vec2Encoder(nn.Module):
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -676,6 +677,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
self.layers = nn.ModuleList(
[Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def forward(
self,
......@@ -715,7 +717,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -842,6 +844,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -864,6 +867,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)):
module.gradient_checkpointing = value
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
......
......@@ -990,7 +990,7 @@ class Trainer:
elif isinstance(model, PreTrainedModel):
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False)
else:
find_unused_parameters = True
model = nn.parallel.DistributedDataParallel(
......@@ -1162,6 +1162,10 @@ class Trainer:
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
......
......@@ -372,6 +372,8 @@ class TrainingArguments:
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
"""
output_dir: str = field(
......@@ -650,6 +652,12 @@ class TrainingArguments:
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
# Deprecated arguments
push_to_hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
......
......@@ -72,8 +72,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
{% else -%}
vocab_size (:obj:`int`, `optional`, defaults to 50265):
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
......@@ -186,7 +184,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
decoder_start_token_id=2,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
{% endif -%}
pad_token_id=1,
bos_token_id=0,
......@@ -225,7 +222,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
{% endif -%}
......
......@@ -513,6 +513,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([{{cookiecutter.camelcase_modelname}}Layer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -539,12 +540,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......@@ -664,6 +664,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -682,6 +683,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder):
module.gradient_checkpointing = value
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
......@@ -2006,6 +2011,7 @@ class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -2017,16 +2023,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
}
return dummy_inputs
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)):
module.gradient_checkpointing = value
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
......@@ -2213,6 +2213,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -2309,7 +2310,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -2376,6 +2377,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
......@@ -2545,10 +2547,10 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
logger.warning("`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`...")
use_cache = False
def create_custom_forward(module):
......
......@@ -224,6 +224,27 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training:
return
config.use_cache = False
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
# we don't test BeitForMaskedImageModeling
if model_class.__name__ == "BeitForMaskedImageModeling":
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -370,15 +370,14 @@ class ModelTesterMixin:
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"):
if not self.model_tester.is_training:
return
config.gradient_checkpointing = True
config.use_cache = False
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING):
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
model = model_class(config)
model.to(torch_device)
......
......@@ -20,6 +20,7 @@ import unittest
from transformers import DeiTConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -340,7 +341,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
# DeiTForImageClassificationWithTeacher supports inference-only
if (
model_class in MODEL_MAPPING.values()
model_class in get_values(MODEL_MAPPING)
or model_class.__name__ == "DeiTForImageClassificationWithTeacher"
):
continue
......@@ -351,6 +352,27 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training:
return
config.use_cache = False
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
# DeiTForImageClassificationWithTeacher supports inference-only
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
......
......@@ -82,7 +82,7 @@ class FlaxGPT2ModelTester:
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
......@@ -100,7 +100,6 @@ class FlaxGPT2ModelTester:
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
return (config, input_ids, input_mask)
......
......@@ -86,7 +86,7 @@ class FlaxGPTNeoModelTester:
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
......@@ -105,7 +105,6 @@ class FlaxGPTNeoModelTester:
pad_token_id=self.pad_token_id,
window_size=self.window_size,
attention_types=self.attention_types,
gradient_checkpointing=gradient_checkpointing,
)
return (config, input_ids, input_mask)
......
......@@ -96,7 +96,7 @@ class GPT2ModelTester:
def get_large_model_config(self):
return GPT2Config.from_pretrained("gpt2")
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
......@@ -119,7 +119,7 @@ class GPT2ModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
config = self.get_config()
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......@@ -135,7 +135,7 @@ class GPT2ModelTester:
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
def get_config(self):
return GPT2Config(
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
......@@ -149,11 +149,10 @@ class GPT2ModelTester:
n_ctx=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
def prepare_config_and_inputs_for_decoder(self):
......@@ -322,9 +321,13 @@ class GPT2ModelTester:
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = GPT2LMHeadModel(config)
model.to(torch_device)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
......@@ -478,8 +481,8 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_gpt2_for_token_classification(*config_and_inputs)
def test_gpt2_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
@slow
def test_batch_generation(self):
......@@ -612,7 +615,11 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gpt2(self):
for checkpointing in [True, False]:
model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
model = GPT2LMHeadModel.from_pretrained("gpt2")
if checkpointing:
model.gradient_checkpointing_enable()
else:
model.gradient_checkpointing_disable()
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
expected_output_ids = [
......
......@@ -97,7 +97,7 @@ class GPTNeoModelTester:
def get_large_model_config(self):
return GPTNeoConfig.from_pretrained("gpt_neo")
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
......@@ -120,7 +120,7 @@ class GPTNeoModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(gradient_checkpointing=False)
config = self.get_config()
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......@@ -136,18 +136,17 @@ class GPTNeoModelTester:
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
def get_config(self):
return GPTNeoConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_layers=self.num_hidden_layers,
num_heads=self.num_attention_heads,
max_position_embeddings=self.max_position_embeddings,
use_cache=not gradient_checkpointing,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
window_size=self.window_size,
attention_types=self.attention_types,
)
......@@ -329,8 +328,12 @@ class GPTNeoModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = GPTNeoForCausalLM(config)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
model.to(torch_device)
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
......@@ -411,8 +414,8 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)
def test_gpt_neo_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
def _get_hidden_states(self):
return torch.tensor(
......@@ -473,7 +476,10 @@ class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
def test_lm_generate_gpt_neo(self):
for checkpointing in [True, False]:
model = self.model
model.config.gradient_checkpointing = checkpointing
if checkpointing:
model.gradient_checkpointing_enable()
else:
model.gradient_checkpointing_disable()
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
# fmt: off
# The dog-eared copy of the book, which is a collection of essays by the late author,
......
......@@ -92,7 +92,7 @@ class GPTJModelTester:
def get_large_model_config(self):
return GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
......@@ -115,7 +115,7 @@ class GPTJModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
config = self.get_config()
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......@@ -131,7 +131,7 @@ class GPTJModelTester:
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
def get_config(self):
return GPTJConfig(
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
......@@ -145,11 +145,10 @@ class GPTJModelTester:
n_ctx=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
def prepare_config_and_inputs_for_decoder(self):
......@@ -318,8 +317,12 @@ class GPTJModelTester:
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = GPTJForCausalLM(config)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
model.to(torch_device)
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
......@@ -390,8 +393,8 @@ class GPTJModelTest(unittest.TestCase):
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
def test_gptj_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
@slow
def test_batch_generation(self):
......@@ -464,7 +467,11 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gptj(self):
for checkpointing in [True, False]:
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", gradient_checkpointing=checkpointing)
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
if checkpointing:
model.gradient_checkpointing_enable()
else:
model.gradient_checkpointing_disable()
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
expected_output_ids = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment