Unverified Commit 06e782da authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
......@@ -605,20 +605,15 @@ class EsmEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -710,9 +705,10 @@ class EsmPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, EsmEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ESM_START_DOCSTRING = r"""
......
......@@ -1097,9 +1097,10 @@ class FalconPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
# Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
if isinstance(module, FalconModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@staticmethod
def _convert_cache_to_standard_format(
......@@ -1278,21 +1279,16 @@ class FalconModel(FalconPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self.gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
attention_mask,
position_ids,
head_mask[i],
layer_past,
use_cache,
output_attentions,
)
else:
outputs = block(
......
......@@ -663,18 +663,12 @@ class FlavaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
......@@ -879,9 +873,10 @@ class FlavaPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: FlavaEncoder, value: bool = False) -> None:
def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, FlavaEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings(
......
......@@ -292,14 +292,7 @@ class FNetEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states)
layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states)
else:
layer_outputs = layer_module(hidden_states)
......@@ -431,9 +424,10 @@ class FNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FNetEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
......
......@@ -586,15 +586,8 @@ class FocalNetEncoder(nn.Module):
for i, stage_module in enumerate(self.stages):
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
stage_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(stage_module),
stage_outputs = self.gradient_checkpointing_func(
stage_module.__call__,
hidden_states,
input_dimensions,
)
......@@ -659,9 +652,10 @@ class FocalNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FocalNetEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
FOCALNET_START_DOCSTRING = r"""
......
......@@ -70,9 +70,10 @@ class FuyuPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FuyuForCausalLM):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
FUYU_INPUTS_DOCSTRING = r"""
......
......@@ -452,18 +452,13 @@ class GitEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -533,9 +528,10 @@ class GitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GitEncoder, GitVisionEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GIT_START_DOCSTRING = r"""
......@@ -878,18 +874,12 @@ class GitVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
......
......@@ -480,9 +480,10 @@ class GPT2PreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPT2Model):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
......@@ -877,22 +878,16 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self.gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
......@@ -1623,7 +1618,6 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
# Model parallel
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
......
......@@ -405,9 +405,10 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTBigCodeModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPT_BIGCODE_START_DOCSTRING = r"""
......@@ -650,22 +651,16 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self.gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
......
......@@ -384,9 +384,10 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPT_NEO_START_DOCSTRING = r"""
......@@ -604,20 +605,14 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self.gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(
......
......@@ -78,9 +78,10 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoXModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GPTNeoXAttention(nn.Module):
......@@ -641,20 +642,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for layer_past
return module(*inputs, use_cache, None, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
outputs = self.gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
position_ids,
head_mask[i],
use_cache,
None,
output_attentions,
)
else:
outputs = layer(
......
......@@ -66,9 +66,10 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoXJapaneseModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GPTNeoXJapaneseAttention(nn.Module):
......
......@@ -363,9 +363,10 @@ class GPTJPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTJModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPTJ_START_DOCSTRING = r"""
......@@ -669,21 +670,15 @@ class GPTJModel(GPTJPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self.gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(
......
......@@ -759,9 +759,10 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GPTSanJapaneseAttention,)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
def _shift_right(self, input_ids):
......
......@@ -772,9 +772,10 @@ class GraphormerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GraphormerModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GraphormerModel(GraphormerPreTrainedModel):
......
......@@ -492,7 +492,6 @@ class GroupViTStage(nn.Module):
self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size))
else:
self.group_token = None
self.gradient_checkpointing = False
self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)])
if num_group_token > 0:
......@@ -805,9 +804,10 @@ class GroupViTPreTrainedModel(PreTrainedModel):
nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GROUPVIT_START_DOCSTRING = r"""
......@@ -1031,18 +1031,12 @@ class GroupViTTextEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
......
......@@ -346,15 +346,8 @@ class HubertFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states = self.gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
......@@ -731,17 +724,11 @@ class HubertEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
layer_outputs = self.gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = layer(
......@@ -821,17 +808,11 @@ class HubertEncoderStableLayerNorm(nn.Module):
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
layer_outputs = self.gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = layer(
......@@ -895,9 +876,10 @@ class HubertPreTrainedModel(PreTrainedModel):
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
......
......@@ -40,7 +40,7 @@ from ...utils import (
)
from .configuration_idefics import IdeficsConfig
from .perceiver import IdeficsPerceiverResampler
from .vision import IdeficsVisionTransformer
from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer
logger = logging.get_logger(__name__)
......@@ -978,9 +978,10 @@ class IdeficsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, IdeficsModel):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LLAMA_INPUTS_DOCSTRING = r"""
......@@ -1098,7 +1099,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
self.norm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
......@@ -1339,7 +1339,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
)
use_cache = False
layer_outputs = torch.utils.checkpoint.checkpoint(
layer_outputs = self.gradient_checkpointing_func(
vblock,
decoder_layer,
hidden_states,
......
......@@ -401,18 +401,12 @@ class IdeficsVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
......
......@@ -525,9 +525,10 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ImageGPTModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
IMAGEGPT_START_DOCSTRING = r"""
......@@ -816,22 +817,16 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self.gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
)
else:
outputs = block(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment