Unverified Commit c21e1071 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[deepspeed / m2m_100] make deepspeed zero-3 work with layerdrop (#16717)

* [deepspeed / m2m_100] make deepspeed 3 work with layerdrop

* fix

* revert last
parent 89293a0f
......@@ -24,6 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
......@@ -794,16 +795,21 @@ class M2M100Encoder(M2M100PreTrainedModel):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
......@@ -826,6 +832,9 @@ class M2M100Encoder(M2M100PreTrainedModel):
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
......@@ -1018,57 +1027,66 @@ class M2M100Decoder(M2M100PreTrainedModel):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if skip_the_layer:
continue
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment