Unverified Commit d7c31abf authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix some TF slow tests (#9728)

* Fix saved model tests + fix a graph issue in longformer

* Apply style
parent 08b22722
...@@ -2438,10 +2438,16 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque ...@@ -2438,10 +2438,16 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
logger.info("Initializing global attention on CLS token...") logger.info("Initializing global attention on CLS token...")
# global attention on cls token # global attention on cls token
inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"]) inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
updates = tf.ones(shape_list(inputs["input_ids"])[0], dtype=tf.int32)
indices = tf.pad(
tensor=tf.expand_dims(tf.range(shape_list(inputs["input_ids"])[0]), axis=1),
paddings=[[0, 0], [0, 1]],
constant_values=0,
)
inputs["global_attention_mask"] = tf.tensor_scatter_nd_update( inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
inputs["global_attention_mask"], inputs["global_attention_mask"],
[[i, 0] for i in range(shape_list(inputs["input_ids"])[0])], indices,
[1 for _ in range(shape_list(inputs["input_ids"])[0])], updates,
) )
outputs = self.longformer( outputs = self.longformer(
......
...@@ -184,7 +184,7 @@ class TFModelTesterMixin: ...@@ -184,7 +184,7 @@ class TFModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True) model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model") saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
self.assertTrue(os.path.exists(saved_model_dir)) self.assertTrue(os.path.exists(saved_model_dir))
@slow @slow
...@@ -204,7 +204,7 @@ class TFModelTesterMixin: ...@@ -204,7 +204,7 @@ class TFModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True) model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model") saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
self.assertTrue(os.path.exists(saved_model_dir)) self.assertTrue(os.path.exists(saved_model_dir))
@slow @slow
...@@ -223,7 +223,8 @@ class TFModelTesterMixin: ...@@ -223,7 +223,8 @@ class TFModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True) model.save_pretrained(tmpdirname, saved_model=True)
model = tf.keras.models.load_model(os.path.join(tmpdirname, "saved_model", "1")) saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict) outputs = model(class_inputs_dict)
if self.is_encoder_decoder: if self.is_encoder_decoder:
...@@ -262,7 +263,8 @@ class TFModelTesterMixin: ...@@ -262,7 +263,8 @@ class TFModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True) model.save_pretrained(tmpdirname, saved_model=True)
model = tf.keras.models.load_model(os.path.join(tmpdirname, "saved_model", "1")) saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict) outputs = model(class_inputs_dict)
if self.is_encoder_decoder: if self.is_encoder_decoder:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment