Unverified Commit af5c3329 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

remove "inputs" in tf common test script (no longer required) (#15262)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent d12ae816
...@@ -377,10 +377,6 @@ class TFModelTesterMixin: ...@@ -377,10 +377,6 @@ class TFModelTesterMixin:
else: else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
with torch.no_grad(): with torch.no_grad():
pto = pt_model(**pt_inputs_dict) pto = pt_model(**pt_inputs_dict)
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False) tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
...@@ -422,9 +418,6 @@ class TFModelTesterMixin: ...@@ -422,9 +418,6 @@ class TFModelTesterMixin:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else: else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
with torch.no_grad(): with torch.no_grad():
pto = pt_model(**pt_inputs_dict) pto = pt_model(**pt_inputs_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment