"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "1215c4201f44e47efb9ac402fc77885a6d12fa9a"
Commit 9a7fbd0b authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[translation] Fix export by updating export utils so that model.call is used...

[translation] Fix export by updating export utils so that, by default, model.call is used rather than model.__call__, which contains extra Keras tracing logic.

PiperOrigin-RevId: 430546354
parent 04585eca
...@@ -68,8 +68,17 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -68,8 +68,17 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
if inference_step is not None: if inference_step is not None:
self.inference_step = functools.partial(inference_step, model=self.model) self.inference_step = functools.partial(inference_step, model=self.model)
else: else:
self.inference_step = functools.partial( if issubclass(type(model), tf.keras.Model):
self.model.__call__, training=False) # Default to self.model.call instead of self.model.__call__ to avoid
# keras tracing logic designed for training.
# Since most of Model Garden's call doesn't have training kwargs
# or the default is False, we don't pass anything here.
# Please pass custom inference step if your model has training=True as
# default.
self.inference_step = self.model.call
else:
self.inference_step = functools.partial(
self.model.__call__, training=False)
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.postprocessor = postprocessor self.postprocessor = postprocessor
......
...@@ -20,6 +20,7 @@ from absl.testing import parameterized ...@@ -20,6 +20,7 @@ from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from sentencepiece import SentencePieceTrainer from sentencepiece import SentencePieceTrainer
from official.core import export_base
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.serving import serving_modules from official.nlp.serving import serving_modules
...@@ -369,5 +370,18 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase): ...@@ -369,5 +370,18 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(outputs.shape, (2,)) self.assertEqual(outputs.shape, (2,))
self.assertEqual(outputs.dtype, tf.string) self.assertEqual(outputs.dtype, tf.string)
tmp_dir = self.get_temp_dir()
export_base_dir = os.path.join(tmp_dir, "export")
ckpt_dir = os.path.join(tmp_dir, "ckpt")
ckpt_path = tf.train.Checkpoint(model=model).save(ckpt_dir)
export_dir = export_base.export(export_module,
{"serve_text": "serving_default"},
export_base_dir, ckpt_path)
loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]
out = infer(text=tf.constant(["abcd", "ef gh"]))
self.assertLen(out["output_0"], 2)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment