Unverified Commit fd673510 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Make PT/Flax tests could be run on GPU (#24557)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent faae8d82
...@@ -611,7 +611,7 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -611,7 +611,7 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
# convert inputs to Flax # convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple() fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
...@@ -669,7 +669,7 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -669,7 +669,7 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple() fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
......
...@@ -592,7 +592,7 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) ...@@ -592,7 +592,7 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
# convert inputs to Flax # convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple() fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
...@@ -650,7 +650,7 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) ...@@ -650,7 +650,7 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_outputs = fx_model(**fx_inputs).to_tuple() fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
......
...@@ -875,7 +875,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -875,7 +875,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
} }
# convert inputs to Flax # convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state fx_model.params = fx_state
...@@ -948,7 +948,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -948,7 +948,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
} }
# convert inputs to Flax # convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
...@@ -1805,7 +1805,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. ...@@ -1805,7 +1805,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
} }
# convert inputs to Flax # convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state fx_model.params = fx_state
...@@ -1878,7 +1878,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. ...@@ -1878,7 +1878,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
} }
# convert inputs to Flax # convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
......
...@@ -2206,7 +2206,7 @@ class ModelTesterMixin: ...@@ -2206,7 +2206,7 @@ class ModelTesterMixin:
} }
# convert inputs to Flax # convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state fx_model.params = fx_state
...@@ -2278,7 +2278,7 @@ class ModelTesterMixin: ...@@ -2278,7 +2278,7 @@ class ModelTesterMixin:
} }
# convert inputs to Flax # convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment