Unverified Commit d735b074 authored by Julien Plu, committed by GitHub

Fix Flaubert (#9292)

parent 5dd389d1
@@ -17,6 +17,7 @@
 """
 import itertools
+import random
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -596,15 +597,15 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         tensor = tensor * mask[..., tf.newaxis]

         # hidden_states and attentions cannot be None in graph mode.
-        hidden_states = ()
-        attentions = ()
+        hidden_states = () if inputs["output_hidden_states"] else None
+        attentions = () if inputs["output_attentions"] else None

         # transformer layers
         for i in range(self.n_layers):
             # LayerDrop
-            dropout_probability = tf.random.uniform([1], 0, 1)
+            dropout_probability = random.uniform(0, 1)

-            if inputs["training"] and tf.less(dropout_probability, self.layerdrop):
+            if inputs["training"] and (dropout_probability < self.layerdrop):
                 continue

             if inputs["output_hidden_states"]:
@@ -642,7 +643,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
             )
             attn = attn_outputs[0]

-            if output_attentions:
+            if inputs["output_attentions"]:
                 attentions = attentions + (attn_outputs[1],)

             attn = self.dropout(attn, training=inputs["training"])
@@ -676,10 +677,6 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         # move back sequence length to dimension 0
         # tensor = tensor.transpose(0, 1)

-        # Set to None here if the output booleans are at False
-        hidden_states = hidden_states if inputs["output_hidden_states"] else None
-        attentions = attentions if inputs["output_attentions"] else None
-
         if not inputs["return_dict"]:
             return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
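With `hidden_states` and `attentions` now `None` from the start whenever their output flags are off, the late reset removed in this hunk becomes redundant, and the tuple return simply drops the `None` entries. A small sketch of that filtering follows; the `build_output` helper and the dict return are illustrative assumptions, not the library's actual output class.

```python
def build_output(tensor, hidden_states, attentions, return_dict):
    if not return_dict:
        # Disabled outputs are None and are dropped from the returned tuple.
        return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
    # Illustrative dict stand-in for the model output object used by the library.
    return {"last_hidden_state": tensor, "hidden_states": hidden_states, "attentions": attentions}


# With attentions disabled (None), the tuple contains only the remaining values.
print(build_output("tensor", ("h0", "h1"), None, return_dict=False))
# -> ('tensor', ('h0', 'h1'))
```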