Unverified Commit 3a1aeea3 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix CTRL `test_torchscrip_xxx` CI by updating `_create_and_check_torchscript` (#19786)



* Run inputs before trace

* Run inputs before trace
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 31565ff0
......@@ -658,6 +658,7 @@ class ModelTesterMixin:
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
......@@ -665,11 +666,13 @@ class ModelTesterMixin:
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
model(main_input)
traced_model = torch.jit.trace(model, main_input)
except RuntimeError:
self.fail("Couldn't trace module.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment