"ppocr/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "b9d5a3e756c5c4fd7e3c99f173b92de7e942949c"
Unverified commit 009171d1, authored by amyeroberts and committed by GitHub
Browse files

Ensure PT model is in evaluation mode and lightweight forward pass done (#17970)

parent d6cec458
...@@ -145,7 +145,7 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -145,7 +145,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# recursively, keeping the name of the attribute. # recursively, keeping the name of the attribute.
if isinstance(pt_out, torch.Tensor): if isinstance(pt_out, torch.Tensor):
tensor_difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy())) tensor_difference = np.max(np.abs(pt_out.numpy() - tf_out.numpy()))
differences[attr_name] = tensor_difference differences[attr_name] = tensor_difference
else: else:
root_name = attr_name root_name = attr_name
...@@ -270,9 +270,13 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -270,9 +270,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
# Load models and acquire a basic input compatible with the model. # Load models and acquire a basic input compatible with the model.
pt_model = pt_class.from_pretrained(self._local_dir) pt_model = pt_class.from_pretrained(self._local_dir)
pt_model.eval()
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True) tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
pt_input, tf_input = self.get_inputs(pt_model, config) pt_input, tf_input = self.get_inputs(pt_model, config)
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
with torch.no_grad():
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
del pt_model # will no longer be used, and may have a large memory footprint del pt_model # will no longer be used, and may have a large memory footprint
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True) tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment