Unverified Commit fdb2351e authored by Julien Plu, committed by GitHub

Making TF XLM-like models XLA and AMP compliant (#10211)

* Fix Flaubert and XLM

* Remove useless cast

* Tiny fix

* Tiny fix
parent 83d803ba
@@ -171,7 +171,7 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
 """


-def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
+def get_masks(slen, lengths, causal, padding_mask=None):
     """
     Generate hidden states mask, and optionally an attention mask.
     """
@@ -193,12 +193,10 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
     # sanity check
     # assert shape_list(mask) == [bs, slen]
-    tf.debugging.assert_equal(shape_list(mask), [bs, slen])
+    if tf.executing_eagerly():
+        tf.debugging.assert_equal(shape_list(mask), [bs, slen])
     assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
-    mask = tf.cast(mask, dtype=dtype)
-    attn_mask = tf.cast(attn_mask, dtype=dtype)

     return mask, attn_mask
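Two things change in get_masks. First, the sanity check now only runs eagerly: inside a tf.function (and therefore under XLA), shape_list returns symbolic values for unknown dimensions, so asserting them against Python ints at trace time is at best meaningless and at worst breaks compilation. Second, the masks are no longer cast to a fixed float dtype here; callers cast them to whatever dtype they actually compute in, which is what keeps the function usable under mixed precision. A stand-alone sketch of the guard (my own example, not library code):

```python
import tensorflow as tf

def checked_mask(lengths, slen):
    mask = tf.sequence_mask(lengths, slen)  # bool mask, no dtype baked in
    if tf.executing_eagerly():              # True eagerly, False while tracing a tf.function
        tf.debugging.assert_equal(tf.shape(mask)[0], tf.shape(lengths)[0])
    return mask

checked_mask(tf.constant([2, 3]), 4)        # eager: the check runs

@tf.function(jit_compile=True)
def compiled(lengths):
    mask = checked_mask(lengths, 4)         # traced: the check is skipped
    return tf.cast(mask, tf.float16)        # the caller chooses the compute dtype

compiled(tf.constant([2, 3]))
```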
@@ -339,8 +337,7 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
             klen = shape_list(kv)[1]

         # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
-        dim_per_head = tf.math.divide(self.dim, self.n_heads)
-        dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
+        dim_per_head = self.dim // self.n_heads
         mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)

         def shape(x):
@@ -372,8 +369,8 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
             cache[self.layer_id] = (k, v)

-        q = tf.cast(q, dtype=tf.float32)
-        q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))  # (bs, n_heads, qlen, dim_per_head)
+        f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype)
+        q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head))  # (bs, n_heads, qlen, dim_per_head)
         k = tf.cast(k, dtype=q.dtype)
         scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, qlen, klen)
         mask = tf.reshape(mask, mask_reshape)  # (bs, n_heads, qlen, klen)
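Under mixed_float16 the projected q comes out of the Dense layers as float16, so the old tf.cast(q, tf.float32) silently pulled the attention computation back into float32 (k was then re-cast to follow). Casting the per-head size to q.dtype keeps the 1/sqrt(d_head) scaling in whatever compute dtype the policy selected. Roughly, and independent of the diff:

```python
import tensorflow as tf

def scale_queries(q, dim_per_head):
    # cast the Python int to q's dtype instead of hard-coding float32
    return q * tf.math.rsqrt(tf.cast(dim_per_head, dtype=q.dtype))

q = tf.random.normal((2, 4, 5, 8), dtype=tf.float16)  # (bs, n_heads, qlen, dim_per_head)
assert scale_queries(q, 8).dtype == tf.float16        # stays in the AMP compute dtype
```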
@@ -438,22 +435,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.return_dict = config.use_return_dict
+        self.max_position_embeddings = config.max_position_embeddings
+        self.embed_init_std = config.embed_init_std
         self.dropout = tf.keras.layers.Dropout(config.dropout)
-        self.position_embeddings = tf.keras.layers.Embedding(
-            config.max_position_embeddings,
-            self.dim,
-            embeddings_initializer=get_initializer(config.embed_init_std),
-            name="position_embeddings",
-        )
-        if config.n_langs > 1 and config.use_lang_emb:
-            self.lang_embeddings = tf.keras.layers.Embedding(
-                self.n_langs,
-                self.dim,
-                embeddings_initializer=get_initializer(config.embed_init_std),
-                name="lang_embeddings",
-            )
         self.embeddings = TFSharedEmbeddings(
             self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
         )
@@ -482,6 +466,24 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
                 tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))
             )

+    def build(self, input_shape):
+        with tf.name_scope("position_embeddings"):
+            self.position_embeddings = self.add_weight(
+                name="embeddings",
+                shape=[self.max_position_embeddings, self.dim],
+                initializer=get_initializer(self.embed_init_std),
+            )
+
+        if self.n_langs > 1 and self.use_lang_emb:
+            with tf.name_scope("lang_embeddings"):
+                self.lang_embeddings = self.add_weight(
+                    name="embeddings",
+                    shape=[self.n_langs, self.dim],
+                    initializer=get_initializer(self.embed_init_std),
+                )
+
+        super().build(input_shape)
+
     def get_input_embeddings(self):
         return self.embeddings
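The position and language embeddings move from tf.keras.layers.Embedding sub-layers to plain weights created in build() under explicit tf.name_scope blocks, looked up later with tf.gather. A plausible reading (the commit message does not spell it out): a raw variable plus tf.gather keeps the lookup a single, XLA-friendly op, leaves the table in float32 under the mixed-precision policy so the cast point stays explicit, and the name scopes keep the saved weight names stable. A self-contained sketch of the same pattern, not the library class itself:

```python
import tensorflow as tf

class PositionTable(tf.keras.layers.Layer):
    """Minimal stand-in for the build()/add_weight pattern used above."""

    def __init__(self, max_positions, dim, **kwargs):
        super().__init__(**kwargs)
        self.max_positions = max_positions
        self.dim = dim

    def build(self, input_shape):
        # Created lazily; the explicit name scope gives the variable a stable
        # "position_embeddings/embeddings" name inside this layer.
        with tf.name_scope("position_embeddings"):
            self.table = self.add_weight(
                name="embeddings",
                shape=[self.max_positions, self.dim],
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            )
        super().build(input_shape)

    def call(self, position_ids):
        return tf.gather(self.table, position_ids)  # plain lookup, no extra layer machinery

layer = PositionTable(max_positions=512, dim=32)
print(layer(tf.constant([[0, 1, 2]])).shape)  # (1, 3, 32)
```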
@@ -538,14 +540,15 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         if inputs["lengths"] is None:
             if inputs["input_ids"] is not None:
                 inputs["lengths"] = tf.reduce_sum(
-                    tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
+                    tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=inputs["input_ids"].dtype), axis=1
                 )
             else:
-                inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
+                inputs["lengths"] = tf.convert_to_tensor([slen] * bs)

         # mask = input_ids != self.pad_index
         # check inputs
         # assert shape_list(lengths)[0] == bs
-        tf.debugging.assert_equal(
-            shape_list(inputs["lengths"])[0], bs
-        ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(inputs["lengths"])[0], bs
+            ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
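lengths is just the per-row count of non-padding tokens; the diff stops forcing it to tf.int32 (it now inherits the dtype of input_ids) and moves the batch-size assertion behind the same eager-only guard as above. For reference, with a made-up pad index:

```python
import tensorflow as tf

pad_index = 2  # hypothetical padding id
input_ids = tf.constant([[5, 6, 7, 2, 2],
                         [5, 2, 2, 2, 2]])

# Count non-padding tokens per row; dtype follows input_ids (int32 here).
lengths = tf.reduce_sum(
    tf.cast(tf.not_equal(input_ids, pad_index), dtype=input_ids.dtype), axis=1
)
print(lengths.numpy())  # [3 1]
```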
@@ -564,7 +567,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         # position_ids
         if inputs["position_ids"] is None:
             inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
-        else:
+            inputs["position_ids"] = tf.tile(inputs["position_ids"], (bs, 1))
+
+        if tf.executing_eagerly():
             # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
                 shape_list(inputs["position_ids"]), [bs, slen]
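When no position_ids are passed, the model used to build a (1, slen) range and only shape-checked user-supplied ids in the else branch. It now always tiles the range out to a full (bs, slen) tensor and runs the shape assertion only in eager mode, so tracing never hits a check on symbolic shapes. The default construction boils down to:

```python
import tensorflow as tf

bs, slen = 2, 4
position_ids = tf.expand_dims(tf.range(slen), axis=0)  # (1, slen) -> [[0 1 2 3]]
position_ids = tf.tile(position_ids, (bs, 1))          # (bs, slen), one row per example
print(position_ids.numpy())  # [[0 1 2 3]
                             #  [0 1 2 3]]
```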
@@ -572,7 +577,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         # position_ids = position_ids.transpose(0, 1)

         # langs
-        if inputs["langs"] is not None:
+        if inputs["langs"] is not None and tf.executing_eagerly():
             # assert shape_list(langs) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
                 shape_list(inputs["langs"]), [bs, slen]
@@ -603,15 +608,16 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
         if inputs["inputs_embeds"] is None:
             inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])

-        tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
+        tensor = inputs["inputs_embeds"] + tf.gather(self.position_embeddings, inputs["position_ids"])

         if inputs["langs"] is not None and self.use_lang_emb:
-            tensor = tensor + self.lang_embeddings(inputs["langs"])
+            tensor = tensor + tf.gather(self.lang_embeddings, inputs["langs"])
         if inputs["token_type_ids"] is not None:
             tensor = tensor + self.embeddings(inputs["token_type_ids"])

         tensor = self.layer_norm_emb(tensor)
         tensor = self.dropout(tensor, training=inputs["training"])
+        mask = tf.cast(mask, dtype=tensor.dtype)
         tensor = tensor * tf.expand_dims(mask, axis=-1)

         # hidden_states and attentions cannot be None in graph mode.
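Because get_masks no longer casts, the padding mask arrives here as a bool/int tensor and is cast once to tensor.dtype — float16 under the mixed-precision policy, float32 otherwise — right before zeroing the padded positions. In isolation:

```python
import tensorflow as tf

tensor = tf.random.normal((2, 4, 8), dtype=tf.float16)  # embedding output, e.g. under mixed_float16
mask = tf.constant([[1, 1, 1, 0],
                    [1, 1, 0, 0]])                       # integer padding mask, (bs, slen)

mask = tf.cast(mask, dtype=tensor.dtype)                 # follow the compute dtype
tensor = tensor * tf.expand_dims(mask, axis=-1)          # zero out padded positions
```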
@@ -804,7 +810,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
         lang_id = self.config.lang_id

         effective_batch_size = inputs.shape[0]
-        mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id
+        mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id
         inputs = tf.concat([inputs, mask_token], axis=1)

         if lang_id is not None:
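In prepare_inputs_for_generation, the appended mask-token column is now built with tf.fill instead of tf.ones(..., dtype=tf.int32); tf.fill takes its dtype from the fill value, so nothing is pinned to int32 by hand (the commit does not state the motivation, but it fits the general removal of hard-coded dtypes). With hypothetical ids:

```python
import tensorflow as tf

mask_token_id = 5  # hypothetical
inputs = tf.constant([[11, 12, 13],
                      [14, 15, 16]])

effective_batch_size = inputs.shape[0]
mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id  # dtype inferred from the fill value
inputs = tf.concat([inputs, mask_token], axis=1)
print(inputs.numpy())  # [[11 12 13  5]
                       #  [14 15 16  5]]
```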
......
@@ -82,7 +82,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out):
     out[:, 1::2] = tf.constant(np.cos(position_enc[:, 1::2]))


-def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
+def get_masks(slen, lengths, causal, padding_mask=None):
     """
     Generate hidden states mask, and optionally an attention mask.
     """
@@ -104,12 +104,10 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
     # sanity check
     # assert shape_list(mask) == [bs, slen]
-    tf.debugging.assert_equal(shape_list(mask), [bs, slen])
+    if tf.executing_eagerly():
+        tf.debugging.assert_equal(shape_list(mask), [bs, slen])
     assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
-    mask = tf.cast(mask, dtype=dtype)
-    attn_mask = tf.cast(attn_mask, dtype=dtype)

     return mask, attn_mask
@@ -148,8 +146,7 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
             klen = shape_list(kv)[1]

         # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
-        dim_per_head = tf.math.divide(self.dim, self.n_heads)
-        dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
+        dim_per_head = self.dim // self.n_heads
         mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)

         def shape(x):
@@ -181,8 +178,8 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
             cache[self.layer_id] = (k, v)

-        q = tf.cast(q, dtype=tf.float32)
-        q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))  # (bs, n_heads, qlen, dim_per_head)
+        f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype)
+        q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head))  # (bs, n_heads, qlen, dim_per_head)
         k = tf.cast(k, dtype=q.dtype)
         scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, qlen, klen)
         mask = tf.reshape(mask, mask_reshape)  # (bs, n_heads, qlen, klen)
@@ -263,30 +260,18 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         self.hidden_dim = self.dim * 4  # 2048 by default
         self.n_heads = config.n_heads  # 8 by default
         self.n_layers = config.n_layers
+        self.max_position_embeddings = config.max_position_embeddings
+        self.embed_init_std = config.embed_init_std
         assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"

         # embeddings
         self.dropout = tf.keras.layers.Dropout(config.dropout)
         self.attention_dropout = tf.keras.layers.Dropout(config.attention_dropout)
-        self.position_embeddings = tf.keras.layers.Embedding(
-            config.max_position_embeddings,
-            self.dim,
-            embeddings_initializer=get_initializer(config.embed_init_std),
-            name="position_embeddings",
-        )
         if config.sinusoidal_embeddings:
             raise NotImplementedError
             # create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
-        if config.n_langs > 1 and config.use_lang_emb:
-            self.lang_embeddings = tf.keras.layers.Embedding(
-                self.n_langs,
-                self.dim,
-                embeddings_initializer=get_initializer(config.embed_init_std),
-                name="lang_embeddings",
-            )
         self.embeddings = TFSharedEmbeddings(
             self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
         )  # padding_idx=self.pad_index)
@@ -326,6 +311,24 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             if self.attentions[int(layer)].n_heads == config.n_heads:
                 self.prune_heads({int(layer): list(map(int, heads))})

+    def build(self, input_shape):
+        with tf.name_scope("position_embeddings"):
+            self.position_embeddings = self.add_weight(
+                name="embeddings",
+                shape=[self.max_position_embeddings, self.dim],
+                initializer=get_initializer(self.embed_init_std),
+            )
+
+        if self.n_langs > 1 and self.use_lang_emb:
+            with tf.name_scope("lang_embeddings"):
+                self.lang_embeddings = self.add_weight(
+                    name="embeddings",
+                    shape=[self.n_langs, self.dim],
+                    initializer=get_initializer(self.embed_init_std),
+                )
+
+        super().build(input_shape)
+
     def get_input_embeddings(self):
         return self.embeddings
@@ -389,14 +392,15 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         if inputs["lengths"] is None:
             if inputs["input_ids"] is not None:
                 inputs["lengths"] = tf.reduce_sum(
-                    tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
+                    tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=inputs["input_ids"].dtype), axis=1
                 )
             else:
-                inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
+                inputs["lengths"] = tf.convert_to_tensor([slen] * bs)

         # mask = input_ids != self.pad_index
         # check inputs
         # assert shape_list(lengths)[0] == bs
-        tf.debugging.assert_equal(
-            shape_list(inputs["lengths"])[0], bs
-        ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(inputs["lengths"])[0], bs
+            ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
@@ -415,7 +419,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         # position_ids
         if inputs["position_ids"] is None:
             inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
-        else:
+            inputs["position_ids"] = tf.tile(inputs["position_ids"], (bs, 1))
+
+        if tf.executing_eagerly():
             # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
                 shape_list(inputs["position_ids"]), [bs, slen]
@@ -423,7 +429,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         # position_ids = position_ids.transpose(0, 1)

         # langs
-        if inputs["langs"] is not None:
+        if inputs["langs"] is not None and tf.executing_eagerly():
             # assert shape_list(langs) == [bs, slen]  # (slen, bs)
             tf.debugging.assert_equal(
                 shape_list(inputs["langs"]), [bs, slen]
@@ -454,15 +460,16 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         if inputs["inputs_embeds"] is None:
             inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])

-        tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
+        tensor = inputs["inputs_embeds"] + tf.gather(self.position_embeddings, inputs["position_ids"])

         if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1:
-            tensor = tensor + self.lang_embeddings(inputs["langs"])
+            tensor = tensor + tf.gather(self.lang_embeddings, inputs["langs"])
         if inputs["token_type_ids"] is not None:
             tensor = tensor + self.embeddings(inputs["token_type_ids"])

         tensor = self.layer_norm_emb(tensor)
         tensor = self.dropout(tensor, training=inputs["training"])
+        mask = tf.cast(mask, dtype=tensor.dtype)
         tensor = tensor * tf.expand_dims(mask, axis=-1)

         # transformer layers
@@ -837,7 +844,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
         lang_id = self.config.lang_id

         effective_batch_size = inputs.shape[0]
-        mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id
+        mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id
         inputs = tf.concat([inputs, mask_token], axis=1)

         if lang_id is not None:
......
@@ -331,14 +331,6 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
             model = TFFlaubertModel.from_pretrained(model_name)
             self.assertIsNotNone(model)

-    def test_mixed_precision(self):
-        # TODO JP: Make Flaubert float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make Flaubert XLA compliant
-        pass
-

 @require_tf
 @require_sentencepiece
......
@@ -327,14 +327,6 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)

-    def test_mixed_precision(self):
-        # TODO JP: Make XLM float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make XLM XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......