Unverified commit 27d46397 authored by Sylvain Gugger, committed by GitHub

Make gradient_checkpointing a training argument (#13657)



* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
parent 75f6641e
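In short, `gradient_checkpointing` stops being a model-config attribute and becomes a runtime switch: a `TrainingArguments` flag consumed by the `Trainer`, plus `gradient_checkpointing_enable()`/`gradient_checkpointing_disable()` methods on the model. A minimal before/after sketch of the user-facing change (the `gpt2` checkpoint simply mirrors the updated tests further down):

```python
from transformers import GPT2LMHeadModel

# Before this commit: GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=True)
# After it: load the model normally and toggle checkpointing at runtime.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.gradient_checkpointing_enable()   # trade extra compute in backward for lower activation memory
model.gradient_checkpointing_disable()  # restore the default behaviour
```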
@@ -57,8 +57,6 @@ class ViTConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
             The size (resolution) of each image.
         patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
...
@@ -352,6 +352,7 @@ class ViTEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -370,7 +371,7 @@ class ViTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -411,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel):

     config_class = ViTConfig
     base_model_prefix = "vit"
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -428,6 +430,10 @@ class ViTPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, ViTEncoder):
+            module.gradient_checkpointing = value
+

VIT_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
...
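The ViT hunks above show only the per-model side of the new mechanism: a `gradient_checkpointing` flag on the encoder, a `supports_gradient_checkpointing` class attribute, and a `_set_gradient_checkpointing` hook. The `gradient_checkpointing_enable()`/`gradient_checkpointing_disable()` methods that drive this hook live in `modeling_utils.py`, which this excerpt does not show; the sketch below is a hedged reconstruction of that plumbing, not the exact code from the commit.

```python
from functools import partial

import torch.nn as nn


class PreTrainedModelSketch(nn.Module):
    """Hedged sketch of the base-class side of this change (not the actual modeling_utils code)."""

    supports_gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Overridden per model; e.g. ViTPreTrainedModel flips the flag on ViTEncoder.
        pass

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # Walk every sub-module and let the model-specific hook switch the flag on.
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def gradient_checkpointing_disable(self):
        if self.supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))
```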
@@ -138,8 +138,6 @@ class Wav2Vec2Config(PretrainedConfig):
             instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
         classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
             Dimensionality of the projection before token mean-pooling for classification.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.

     Example::
@@ -198,7 +196,6 @@ class Wav2Vec2Config(PretrainedConfig):
         ctc_zero_infinity=False,
         use_weighted_layer_sum=False,
         classifier_proj_size=256,
-        gradient_checkpointing=False,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
@@ -229,7 +226,6 @@ class Wav2Vec2Config(PretrainedConfig):
         self.initializer_range = initializer_range
         self.vocab_size = vocab_size
         self.do_stable_layer_norm = do_stable_layer_norm
-        self.gradient_checkpointing = gradient_checkpointing
         self.use_weighted_layer_sum = use_weighted_layer_sum
         self.classifier_proj_size = classifier_proj_size
...
@@ -590,6 +590,7 @@ class Wav2Vec2Encoder(nn.Module):
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -629,7 +630,7 @@ class Wav2Vec2Encoder(nn.Module):
             skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:
                     # create gradient checkpointing function
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -676,6 +677,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
         self.layers = nn.ModuleList(
             [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
         )
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -715,7 +717,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:
                     # create gradient checkpointing function
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -842,6 +844,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):

     config_class = Wav2Vec2Config
     base_model_prefix = "wav2vec2"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
@@ -864,6 +867,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
             if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
                 module.bias.data.zero_()

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)):
+            module.gradient_checkpointing = value
+
     def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
         """
         Computes the output length of the convolutional layers
...
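Since `gradient_checkpointing` is no longer a `Wav2Vec2Config` argument, fine-tuning scripts that passed it through the config should switch to the model-level toggle. A minimal sketch of the migrated call, using the `facebook/wav2vec2-base-960h` checkpoint purely for illustration:

```python
from transformers import Wav2Vec2ForCTC

# Old style (removed by this commit): Wav2Vec2Config(..., gradient_checkpointing=True)
# New style: enable checkpointing on the instantiated model instead.
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.gradient_checkpointing_enable()
model.train()  # checkpointing only takes effect while the module is in training mode
```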
@@ -990,7 +990,7 @@ class Trainer:
             elif isinstance(model, PreTrainedModel):
                 # find_unused_parameters breaks checkpointing as per
                 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
+                find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False)
             else:
                 find_unused_parameters = True
             model = nn.parallel.DistributedDataParallel(
@@ -1162,6 +1162,10 @@ class Trainer:
         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None

+        # Activate gradient checkpointing if needed
+        if args.gradient_checkpointing:
+            self.model.gradient_checkpointing_enable()
+
         model = self._wrap_model(self.model_wrapped)

         # for the rest of this function `model` is the outside model, whether it was wrapped or not
...
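With the Trainer change above, checkpointing is requested through `TrainingArguments` and enabled on the underlying model before it is wrapped for distributed training. A minimal usage sketch; `model`, `train_dataset`, and the output directory are placeholders assumed to be defined elsewhere:

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",                 # placeholder path
    gradient_checkpointing=True,      # new training argument introduced by this commit
    per_device_train_batch_size=8,
)
# model and train_dataset are assumed to exist in the surrounding script.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()  # Trainer calls model.gradient_checkpointing_enable() before wrapping the model
```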
@@ -372,6 +372,8 @@ class TrainingArguments:
         hub_token (:obj:`str`, `optional`):
             The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
             :obj:`huggingface-cli login`.
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
     """

     output_dir: str = field(
@@ -650,6 +652,12 @@ class TrainingArguments:
         metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
     )
     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )

     # Deprecated arguments
     push_to_hub_model_id: str = field(
         default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
...
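Because the option is an ordinary `TrainingArguments` field, scripts that build their arguments with `HfArgumentParser` should pick it up as a `--gradient_checkpointing` command-line option automatically. A short sketch of that assumed behaviour:

```python
from transformers import HfArgumentParser, TrainingArguments

# Invoked e.g. as: python train.py --output_dir out --gradient_checkpointing
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()
print(training_args.gradient_checkpointing)  # True when the flag is passed, False otherwise
```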
@@ -72,8 +72,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models). Only
             relevant if ``config.is_decoder=True``.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
     {% else -%}
         vocab_size (:obj:`int`, `optional`, defaults to 50265):
             Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
@@ -186,7 +184,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
         decoder_start_token_id=2,
         classifier_dropout=0.0,
         scale_embedding=False,
-        gradient_checkpointing=False,
     {% endif -%}
         pad_token_id=1,
         bos_token_id=0,
@@ -225,7 +222,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
-        self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
     {% endif -%}
...
@@ -513,6 +513,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([{{cookiecutter.camelcase_modelname}}Layer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -539,12 +540,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
@@ -664,6 +664,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
     config_class = {{cookiecutter.camelcase_modelname}}Config
     load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
     base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
@@ -682,6 +683,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder):
+            module.gradient_checkpointing = value
+

{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@@ -2006,6 +2011,7 @@ class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
     config_class = {{cookiecutter.camelcase_modelname}}Config
     base_model_prefix = "model"
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         std = self.config.init_std
@@ -2017,16 +2023,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

-    @property
-    def dummy_inputs(self):
-        pad_token = self.config.pad_token_id
-        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
-        dummy_inputs = {
-            "attention_mask": input_ids.ne(pad_token),
-            "input_ids": input_ids,
-        }
-        return dummy_inputs
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)):
+            module.gradient_checkpointing = value


{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
@@ -2213,6 +2213,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.layernorm_embedding = nn.LayerNorm(embed_dim)

         self.init_weights()
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -2309,7 +2310,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -2376,6 +2377,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.layernorm_embedding = nn.LayerNorm(config.d_model)

         self.init_weights()
+        self.gradient_checkpointing = False

     def get_input_embeddings(self):
         return self.embed_tokens
@@ -2545,10 +2547,10 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 if use_cache:
-                    logger.warning("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
+                    logger.warning("`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`...")
                     use_cache = False

                 def create_custom_forward(module):
...
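Both the real models and the templates wrap each layer call in `create_custom_forward` and hand it to `torch.utils.checkpoint.checkpoint`; the function bodies are truncated in this view, so the snippet below is an illustrative, self-contained reconstruction of that pattern rather than code from the commit (the layer and tensors are stand-ins):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def create_custom_forward(module):
    # Wrap the layer so torch.utils.checkpoint can re-run it during the backward
    # pass instead of keeping its intermediate activations in memory.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


layer_module = nn.Linear(16, 16)                           # stand-in for an encoder/decoder layer
hidden_states = torch.randn(2, 4, 16, requires_grad=True)  # stand-in activations
layer_outputs = checkpoint(create_custom_forward(layer_module), hidden_states)
layer_outputs.sum().backward()                             # activations are recomputed here
```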
@@ -224,6 +224,27 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
             loss = model(**inputs).loss
             loss.backward()

+    def test_training_gradient_checkpointing(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.model_tester.is_training:
+            return
+
+        config.use_cache = False
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
+                continue
+            # we don't test BeitForMaskedImageModeling
+            if model_class.__name__ == "BeitForMaskedImageModeling":
+                continue
+            model = model_class(config)
+            model.to(torch_device)
+            model.train()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            loss = model(**inputs).loss
+            loss.backward()
+
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -370,15 +370,14 @@ class ModelTesterMixin:
     def test_training_gradient_checkpointing(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"):
+        if not self.model_tester.is_training:
             return

-        config.gradient_checkpointing = True
         config.use_cache = False
         config.return_dict = True

         for model_class in self.all_model_classes:
-            if model_class in get_values(MODEL_MAPPING):
+            if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
                 continue
             model = model_class(config)
             model.to(torch_device)
...
@@ -20,6 +20,7 @@ import unittest

 from transformers import DeiTConfig
 from transformers.file_utils import cached_property, is_torch_available, is_vision_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device

 from .test_configuration_common import ConfigTester
@@ -340,7 +341,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
         for model_class in self.all_model_classes:
             # DeiTForImageClassificationWithTeacher supports inference-only
             if (
-                model_class in MODEL_MAPPING.values()
+                model_class in get_values(MODEL_MAPPING)
                 or model_class.__name__ == "DeiTForImageClassificationWithTeacher"
             ):
                 continue
@@ -351,6 +352,27 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             loss = model(**inputs).loss
             loss.backward()

+    def test_training_gradient_checkpointing(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if not self.model_tester.is_training:
+            return
+
+        config.use_cache = False
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
+                continue
+            # DeiTForImageClassificationWithTeacher supports inference-only
+            if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
+                continue
+            model = model_class(config)
+            model.to(torch_device)
+            model.train()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            loss = model(**inputs).loss
+            loss.backward()
+
     def test_for_image_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
...
@@ -82,7 +82,7 @@ class FlaxGPT2ModelTester:
         self.eos_token_id = vocab_size - 1
         self.pad_token_id = vocab_size - 1

-    def prepare_config_and_inputs(self, gradient_checkpointing=False):
+    def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

         input_mask = None
@@ -100,7 +100,6 @@ class FlaxGPT2ModelTester:
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,
-            gradient_checkpointing=gradient_checkpointing,
         )

         return (config, input_ids, input_mask)
...
@@ -86,7 +86,7 @@ class FlaxGPTNeoModelTester:
         self.eos_token_id = vocab_size - 1
         self.pad_token_id = vocab_size - 1

-    def prepare_config_and_inputs(self, gradient_checkpointing=False):
+    def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

         input_mask = None
@@ -105,7 +105,6 @@ class FlaxGPTNeoModelTester:
             pad_token_id=self.pad_token_id,
             window_size=self.window_size,
             attention_types=self.attention_types,
-            gradient_checkpointing=gradient_checkpointing,
         )

         return (config, input_ids, input_mask)
...
@@ -96,7 +96,7 @@ class GPT2ModelTester:
     def get_large_model_config(self):
         return GPT2Config.from_pretrained("gpt2")

-    def prepare_config_and_inputs(self, gradient_checkpointing=False):
+    def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

         input_mask = None
@@ -119,7 +119,7 @@ class GPT2ModelTester:
             token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
             choice_labels = ids_tensor([self.batch_size], self.num_choices)

-        config = self.get_config(gradient_checkpointing=gradient_checkpointing)
+        config = self.get_config()

         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -135,7 +135,7 @@ class GPT2ModelTester:
             choice_labels,
         )

-    def get_config(self, gradient_checkpointing=False):
+    def get_config(self):
         return GPT2Config(
             vocab_size=self.vocab_size,
             n_embd=self.hidden_size,
@@ -149,11 +149,10 @@ class GPT2ModelTester:
             n_ctx=self.max_position_embeddings,
             type_vocab_size=self.type_vocab_size,
             initializer_range=self.initializer_range,
-            use_cache=not gradient_checkpointing,
+            use_cache=True,
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,
-            gradient_checkpointing=gradient_checkpointing,
         )

     def prepare_config_and_inputs_for_decoder(self):
@@ -322,9 +321,13 @@ class GPT2ModelTester:
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

-    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+    def create_and_check_forward_and_backwards(
+        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
+    ):
         model = GPT2LMHeadModel(config)
         model.to(torch_device)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()

         result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
         self.parent.assertEqual(result.loss.shape, ())
@@ -478,8 +481,8 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         self.model_tester.create_and_check_gpt2_for_token_classification(*config_and_inputs)

     def test_gpt2_gradient_checkpointing(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
-        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)

     @slow
     def test_batch_generation(self):
@@ -612,7 +615,11 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_gpt2(self):
         for checkpointing in [True, False]:
-            model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
+            model = GPT2LMHeadModel.from_pretrained("gpt2")
+            if checkpointing:
+                model.gradient_checkpointing_enable()
+            else:
+                model.gradient_checkpointing_disable()
             model.to(torch_device)
             input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
             expected_output_ids = [
...
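The updated GPT-2 tests toggle checkpointing through the new model-level API rather than the config. The following self-contained snippet mirrors what they exercise; the tiny config values are arbitrary and chosen only to keep the example fast:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Build a small GPT-2, switch checkpointing on via the new API, run forward + backward.
config = GPT2Config(vocab_size=99, n_embd=32, n_layer=2, n_head=4, n_positions=64, use_cache=False)
model = GPT2LMHeadModel(config)
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
loss = model(input_ids, labels=input_ids).loss
loss.backward()  # layer activations are recomputed during this backward pass
```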
@@ -97,7 +97,7 @@ class GPTNeoModelTester:
     def get_large_model_config(self):
         return GPTNeoConfig.from_pretrained("gpt_neo")

-    def prepare_config_and_inputs(self, gradient_checkpointing=False):
+    def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

         input_mask = None
@@ -120,7 +120,7 @@ class GPTNeoModelTester:
             token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
             choice_labels = ids_tensor([self.batch_size], self.num_choices)

-        config = self.get_config(gradient_checkpointing=False)
+        config = self.get_config()

         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -136,18 +136,17 @@ class GPTNeoModelTester:
             choice_labels,
         )

-    def get_config(self, gradient_checkpointing=False):
+    def get_config(self):
         return GPTNeoConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
             num_layers=self.num_hidden_layers,
             num_heads=self.num_attention_heads,
             max_position_embeddings=self.max_position_embeddings,
-            use_cache=not gradient_checkpointing,
+            use_cache=True,
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,
-            gradient_checkpointing=gradient_checkpointing,
             window_size=self.window_size,
             attention_types=self.attention_types,
         )
@@ -329,8 +328,12 @@ class GPTNeoModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

-    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+    def create_and_check_forward_and_backwards(
+        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
+    ):
         model = GPTNeoForCausalLM(config)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
         model.to(torch_device)

         result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
@@ -411,8 +414,8 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)

     def test_gpt_neo_gradient_checkpointing(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
-        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)

     def _get_hidden_states(self):
         return torch.tensor(
@@ -473,7 +476,10 @@ class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
     def test_lm_generate_gpt_neo(self):
         for checkpointing in [True, False]:
             model = self.model
-            model.config.gradient_checkpointing = checkpointing
+            if checkpointing:
+                model.gradient_checkpointing_enable()
+            else:
+                model.gradient_checkpointing_disable()
             input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
             # fmt: off
             # The dog-eared copy of the book, which is a collection of essays by the late author,
...
@@ -92,7 +92,7 @@ class GPTJModelTester:
     def get_large_model_config(self):
         return GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")

-    def prepare_config_and_inputs(self, gradient_checkpointing=False):
+    def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

         input_mask = None
@@ -115,7 +115,7 @@ class GPTJModelTester:
             token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
             choice_labels = ids_tensor([self.batch_size], self.num_choices)

-        config = self.get_config(gradient_checkpointing=gradient_checkpointing)
+        config = self.get_config()

         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -131,7 +131,7 @@ class GPTJModelTester:
             choice_labels,
         )

-    def get_config(self, gradient_checkpointing=False):
+    def get_config(self):
         return GPTJConfig(
             vocab_size=self.vocab_size,
             n_embd=self.hidden_size,
@@ -145,11 +145,10 @@ class GPTJModelTester:
             n_ctx=self.max_position_embeddings,
             type_vocab_size=self.type_vocab_size,
             initializer_range=self.initializer_range,
-            use_cache=not gradient_checkpointing,
+            use_cache=True,
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,
-            gradient_checkpointing=gradient_checkpointing,
         )

     def prepare_config_and_inputs_for_decoder(self):
@@ -318,8 +317,12 @@ class GPTJModelTester:
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

-    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+    def create_and_check_forward_and_backwards(
+        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
+    ):
         model = GPTJForCausalLM(config)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
         model.to(torch_device)

         result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
@@ -390,8 +393,8 @@ class GPTJModelTest(unittest.TestCase):
         self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

     def test_gptj_gradient_checkpointing(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
-        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)

     @slow
     def test_batch_generation(self):
@@ -464,7 +467,11 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_gptj(self):
         for checkpointing in [True, False]:
-            model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", gradient_checkpointing=checkpointing)
+            model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
+            if checkpointing:
+                model.gradient_checkpointing_enable()
+            else:
+                model.gradient_checkpointing_disable()
             model.to(torch_device)
             input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
             expected_output_ids = [
...