Unverified commit 850cf4af authored by Yih-Dar, committed by GitHub

Compute `dropout_probability` only in training mode (#24486)



* fix

* fix

* fix

* fix

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 9895670e
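For context, the change is identical in every file below: the LayerDrop draw `torch.rand([])` used to run on every forward pass, so evaluation and inference consumed the global RNG state (perturbing reproducibility whenever eval runs interleave with training) and paid for a pointless op. The commit moves the draw inside `if self.training:`. Here is a minimal, self-contained sketch of the resulting pattern; the `LayerDropStack` module and its sizes are hypothetical illustrations, not code from the PR:

import torch
from torch import nn


class LayerDropStack(nn.Module):
    """Hypothetical layer stack with LayerDrop (https://arxiv.org/abs/1909.11556)."""

    def __init__(self, num_layers: int = 4, hidden_size: int = 8, layerdrop: float = 0.1):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.layerdrop = layerdrop

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            # After this commit: draw the random number only in training mode,
            # so eval/inference never advances (or desynchronizes) the RNG state.
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True
            if not to_drop:
                hidden_states = layer(hidden_states)
        return hidden_states

With the module in `eval()` mode, a forward pass now leaves torch's RNG untouched, so training runs stay bit-reproducible regardless of how many evaluation passes happen in between.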
@@ -1197,8 +1197,13 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1407,9 +1412,10 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -836,8 +836,13 @@ class BartEncoder(BartPretrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1089,9 +1094,10 @@ class BartDecoder(BartPretrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -1932,8 +1932,13 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -2275,9 +2280,10 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -578,9 +578,10 @@ class BioGptModel(BioGptPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -766,8 +766,13 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1018,9 +1023,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -764,8 +764,13 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1015,9 +1020,10 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -1223,8 +1223,13 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 # we add position_embeddings as extra input to the encoder_layer
@@ -1377,9 +1382,10 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
             if idx == 0:
                 pos_transformation = 1
             else:
...
@@ -978,8 +978,13 @@ class DetrEncoder(DetrPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 # we add position_embeddings as extra input to the encoder_layer
@@ -1117,9 +1122,10 @@ class DetrDecoder(DetrPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             if self.gradient_checkpointing and self.training:
...
@@ -579,9 +579,10 @@ class FlaubertModel(FlaubertPreTrainedModel):
         attentions = () if output_attentions else None
         for i in range(self.n_layers):
             # LayerDrop
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             if output_hidden_states:
                 hidden_states = hidden_states + (tensor,)
...
@@ -793,9 +793,10 @@ class FSMTDecoder(nn.Module):
                 x = x.transpose(0, 1)
                 all_hidden_states += (x,)
                 x = x.transpose(0, 1)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             layer_state = past_key_values[idx] if past_key_values is not None else None
...
@@ -1204,8 +1204,13 @@ class InformerEncoder(InformerPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1424,9 +1429,10 @@ class InformerDecoder(InformerPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -2134,9 +2134,10 @@ class LEDDecoder(LEDPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -777,8 +777,13 @@ class MarianEncoder(MarianPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1023,9 +1028,10 @@ class MarianDecoder(MarianPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -763,9 +763,10 @@ class DetrDecoder(nn.Module):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             if self.gradient_checkpointing and self.training:
...
@@ -818,8 +818,13 @@ class MBartEncoder(MBartPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1073,9 +1078,10 @@ class MBartDecoder(MBartPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -940,8 +940,13 @@ class MvpEncoder(MvpPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1215,9 +1220,10 @@ class MvpDecoder(MvpPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -684,9 +684,10 @@ class OPTDecoder(OPTPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -792,8 +792,13 @@ class PegasusEncoder(PegasusPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1073,9 +1078,10 @@ class PegasusDecoder(PegasusPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -1059,8 +1059,13 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1314,9 +1319,10 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...
@@ -797,8 +797,13 @@ class PLBartEncoder(PLBartPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
@@ -1051,9 +1056,10 @@ class PLBartDecoder(PLBartPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            dropout_probability = torch.rand([])
-            if self.training and (dropout_probability < self.layerdrop):
-                continue
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
...