Unverified Commit c21e1071 authored by Stas Bekman, committed by GitHub

[deepspeed / m2m_100] make deepspeed zero-3 work with layerdrop (#16717)

* [deepspeed / m2m_100] make deepspeed zero-3 work with layerdrop

* fix

* revert last
parent 89293a0f
@@ -24,6 +24,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
+from ...deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -794,16 +795,21 @@ class M2M100Encoder(M2M100PreTrainedModel):
                 raise ValueError(
                     f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
                 )
+        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
-                layer_outputs = (None, None)
-            else:
-                if self.gradient_checkpointing and self.training:
+            skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
+                if self.gradient_checkpointing and self.training:
+                    # create gradient checkpointing function
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
                             return module(*inputs, output_attentions)
@@ -826,6 +832,9 @@ class M2M100Encoder(M2M100PreTrainedModel):
                 hidden_states = layer_outputs[0]

+            if skip_the_layer:
+                layer_outputs = (None, None)
+
             if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
@@ -1018,13 +1027,18 @@ class M2M100Decoder(M2M100PreTrainedModel):
                     raise ValueError(
                         f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
                     )
+        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
         for idx, decoder_layer in enumerate(self.layers):
-            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
                 past_key_value = past_key_values[idx] if past_key_values is not None else None
@@ -1068,8 +1082,12 @@ class M2M100Decoder(M2M100PreTrainedModel):
                         output_attentions=output_attentions,
                         use_cache=use_cache,
                     )
                 hidden_states = layer_outputs[0]

+            if skip_the_layer:
+                continue
+
             if use_cache:
                 next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
...
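
For reference, here is a minimal, self-contained sketch of the pattern both hunks apply. ToyEncoder, its sizes, and the zero3_enabled flag are illustrative stand-ins, not part of the transformers API (in modeling_m2m_100.py the flag comes from is_deepspeed_zero3_enabled(), as added above): when ZeRO-3 is active, a layer selected by LayerDrop is no longer skipped outright, because every rank has to run every layer's forward pass so the partitioned parameters are gathered in sync; the layer still runs and only the outputs collected downstream are nulled out.

import random
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Illustrative module only; mirrors the LayerDrop/ZeRO-3 pattern, not the real M2M100Encoder."""

    def __init__(self, num_layers=4, hidden=16, layerdrop=0.5, zero3_enabled=False):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.layerdrop = layerdrop
        # in the real model this flag is is_deepspeed_zero3_enabled()
        self.zero3_enabled = zero3_enabled

    def forward(self, hidden_states):
        all_attentions = ()
        for layer in self.layers:
            # LayerDrop (https://arxiv.org/abs/1909.11556): randomly drop layers at training time
            dropout_probability = random.uniform(0, 1)
            skip_the_layer = self.training and dropout_probability < self.layerdrop

            if not skip_the_layer or self.zero3_enabled:
                # under deepspeed zero3 all gpus must run in sync, so the layer is
                # executed even when LayerDrop selected it for skipping
                layer_outputs = (layer(hidden_states), None)
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # mark the layer as skipped for anything collected downstream
                layer_outputs = (None, None)

            all_attentions = all_attentions + (layer_outputs[1],)
        return hidden_states, all_attentions


# usage: with zero3_enabled=True every layer executes even in training mode
enc = ToyEncoder(layerdrop=0.5, zero3_enabled=True).train()
hidden, attns = enc(torch.randn(2, 16))

Note the trade-off this implies, which the diff above makes as well: under ZeRO-3 the "dropped" layer's hidden states are still consumed (hidden_states = layer_outputs[0] sits inside the if block), so LayerDrop's skipping is effectively given up in exchange for keeping all ranks synchronized.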