Unverified Commit ddbb485c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[TF-PT-Tests] Fix PyTorch - TF tests for different GPU devices (#15846)

parent 97f9b8a2
......@@ -1493,9 +1493,8 @@ class ModelTesterMixin:
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
# Make sure PyTorch tensors are on same device as model
pt_inputs = {k: v.to(torch_device) if torch.is_tensor(v) else v for k, v in pt_inputs.items()}
with torch.no_grad():
pto = pt_model(**pt_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment