Unverified commit c21e1071 authored by Stas Bekman, committed by GitHub

[deepspeed / m2m_100] make deepspeed zero-3 work with layerdrop (#16717)

* [deepspeed / m2m_100] make deepspeed zero-3 work with layerdrop

* fix

* revert last
parent 89293a0f
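
Why the change is needed, in short: LayerDrop (Fan et al., 2019) randomly skips whole transformer layers during training, and the decision is made independently on every GPU. Under DeepSpeed ZeRO-3 each layer's parameters are partitioned across GPUs and all-gathered the moment the layer is called, so if one rank skips a layer while another runs it, the collective never completes and training hangs. The commit therefore keeps the random skip decision but, when ZeRO-3 is enabled, still calls the layer on every rank and only drops the layer's reported per-layer outputs. The snippet below is a simplified sketch of that pattern, not the m2m_100 code itself: `layers`, `layerdrop` and `training` stand in for the model's attributes, while `is_deepspeed_zero3_enabled` is the helper the diff imports.

# Simplified sketch of the "run in sync under ZeRO-3" LayerDrop pattern.
# Illustrative only; the function and argument names are stand-ins.
import random

import torch
from torch import nn

try:
    from transformers.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
    # transformers not installed (or the helper moved); assume ZeRO-3 is off
    def is_deepspeed_zero3_enabled() -> bool:
        return False


def forward_with_layerdrop(
    layers: nn.ModuleList, hidden_states: torch.Tensor, layerdrop: float, training: bool
) -> torch.Tensor:
    deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
    for layer in layers:
        # LayerDrop: decide randomly, per rank, whether to skip this layer
        dropout_probability = random.uniform(0, 1)
        skip_the_layer = training and dropout_probability < layerdrop
        if not skip_the_layer or deepspeed_zero3_is_enabled:
            # Under ZeRO-3 the layer's parameters are partitioned across GPUs and
            # all-gathered when the layer is called, so every rank must make the
            # call even for a "dropped" layer, otherwise the collective hangs.
            hidden_states = layer(hidden_states)
        # When the layer is dropped, the model simply does not record its
        # per-layer outputs (attentions / KV cache); see the hunks below.
    return hidden_states


# Toy usage: two linear "layers", 50% LayerDrop, training mode.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
out = forward_with_layerdrop(layers, torch.randn(1, 8), layerdrop=0.5, training=True)
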
@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
+from ...deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -794,16 +795,21 @@ class M2M100Encoder(M2M100PreTrainedModel):
                 raise ValueError(
                     f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
                 )
+
+        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
-                layer_outputs = (None, None)
-            else:
-                if self.gradient_checkpointing and self.training:
+
+            skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
+                if self.gradient_checkpointing and self.training:
+                    # create gradient checkpointing function
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
                             return module(*inputs, output_attentions)
@@ -826,6 +832,9 @@ class M2M100Encoder(M2M100PreTrainedModel):
 
                 hidden_states = layer_outputs[0]
 
+            if skip_the_layer:
+                layer_outputs = (None, None)
+
             if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
@@ -1018,57 +1027,66 @@ class M2M100Decoder(M2M100PreTrainedModel):
                     raise ValueError(
                         f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
                     )
+
+        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
         for idx, decoder_layer in enumerate(self.layers):
-            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
+
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
 
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
 
-            if self.gradient_checkpointing and self.training:
+                past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
+                if self.gradient_checkpointing and self.training:
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
+                    if use_cache:
+                        logger.warning(
+                            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                        )
+                        use_cache = False
 
-                    return custom_forward
+                    def create_custom_forward(module):
+                        def custom_forward(*inputs):
+                            # None for past_key_value
+                            return module(*inputs, output_attentions, use_cache)
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
-                    hidden_states,
-                    combined_attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
-                    None,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=combined_attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    cross_attn_layer_head_mask=(
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
-                    ),
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                )
-            hidden_states = layer_outputs[0]
+                        return custom_forward
+
+                    layer_outputs = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(decoder_layer),
+                        hidden_states,
+                        combined_attention_mask,
+                        encoder_hidden_states,
+                        encoder_attention_mask,
+                        head_mask[idx] if head_mask is not None else None,
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                        None,
+                    )
+                else:
+                    layer_outputs = decoder_layer(
+                        hidden_states,
+                        attention_mask=combined_attention_mask,
+                        encoder_hidden_states=encoder_hidden_states,
+                        encoder_attention_mask=encoder_attention_mask,
+                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        cross_attn_layer_head_mask=(
+                            cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                        ),
+                        past_key_value=past_key_value,
+                        output_attentions=output_attentions,
+                        use_cache=use_cache,
+                    )
+
+                hidden_states = layer_outputs[0]
+
+            if skip_the_layer:
+                continue
 
             if use_cache:
                 next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
...
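
For completeness, a note on when `is_deepspeed_zero3_enabled()` actually returns True: inside the HF `Trainer` it is set automatically once a ZeRO stage-3 DeepSpeed config is passed, and outside the `Trainer` the documented pattern at the time of this commit is to create an `HfDeepSpeedConfig` before instantiating the model and keep it alive. A minimal, hedged sketch follows; the config dict is intentionally stripped down, and a real DeepSpeed config also needs batch-size, optimizer and fp16/bf16 sections.

# Sketch of enabling ZeRO-3 detection without the Trainer; requires the
# `deepspeed` package to be installed.
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled

ds_config = {
    "zero_optimization": {"stage": 3},    # stage 3 is what the new code checks for
    "train_micro_batch_size_per_gpu": 1,  # illustrative; not a complete config
}

# Must be created before the model and kept alive; the modeling code above
# consults it whenever is_deepspeed_zero3_enabled() is called.
dschf = HfDeepSpeedConfig(ds_config)

print(is_deepspeed_zero3_enabled())  # True -> the "run all gpus in sync" branch is taken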