"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "8fa47ae8093b3e3006e6f38ab7fe90fe3e687edb"
Commit 60705b10 authored by Frederick Liu, committed by A. Unique TensorFlower

[seq2seq] hardcode batch size

PiperOrigin-RevId: 430563213
parent 9a7fbd0b
@@ -417,6 +417,8 @@ class Translation(export_base.ExportModule):
   @dataclasses.dataclass
   class Params(base_config.Config):
     sentencepiece_model_path: str = ""
+    # Needs to be specified if padded_decode is True/on TPUs.
+    batch_size: Optional[int] = None
 
   def __init__(self, params, model: tf.keras.Model, inference_step=None):
     super().__init__(params, model, inference_step)
@@ -431,6 +433,7 @@ class Translation(export_base.ExportModule):
           "Please make sure the tokenizer generates a single token for an "
           "empty string.")
     self._eos_id = empty_str_tokenized.item()
+    self._batch_size = params.batch_size
 
   @tf.function
   def serve(self, inputs) -> Dict[str, tf.Tensor]:
@@ -452,5 +455,6 @@ class Translation(export_base.ExportModule):
                          (self.__class__, func_key, valid_keys))
       if func_key == "serve_text":
         signatures[signature_key] = self.serve_text.get_concrete_function(
-            tf.TensorSpec(shape=[None], dtype=tf.string, name="text"))
+            tf.TensorSpec(shape=[self._batch_size],
+                          dtype=tf.string, name="text"))
     return signatures
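For context, the exported signature's batch dimension follows the usual tf.TensorSpec semantics: a None entry keeps the dimension dynamic, while an integer pins it, which is what padded decoding on TPUs needs. A minimal sketch of just that semantics (the batch size of 2 below is illustrative, not from this change):

import tensorflow as tf

# Behavior when Params.batch_size is left as None: any number of input
# strings per request is accepted, as before this change.
dynamic_spec = tf.TensorSpec(shape=[None], dtype=tf.string, name="text")

# Behavior when Params.batch_size=2: the batch dimension is pinned, which is
# what padded_decode / TPU serving requires.
fixed_spec = tf.TensorSpec(shape=[2], dtype=tf.string, name="text")

print(dynamic_spec.shape)  # (None,)
print(fixed_spec.shape)    # (2,)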
...@@ -344,7 +344,10 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase): ...@@ -344,7 +344,10 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None}) _ = export_module.get_inference_signatures({"foo": None})
def test_translation(self): @parameterized.parameters(
(False, None),
(True, 2))
def test_translation(self, padded_decode, batch_size):
sp_path = _make_sentencepeice(self.get_temp_dir()) sp_path = _make_sentencepeice(self.get_temp_dir())
encdecoder = translation.EncDecoder( encdecoder = translation.EncDecoder(
num_attention_heads=4, intermediate_size=256) num_attention_heads=4, intermediate_size=256)
...@@ -353,7 +356,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase): ...@@ -353,7 +356,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
encoder=encdecoder, encoder=encdecoder,
decoder=encdecoder, decoder=encdecoder,
embedding_width=256, embedding_width=256,
padded_decode=False, padded_decode=padded_decode,
decode_max_length=100), decode_max_length=100),
sentencepiece_model_path=sp_path, sentencepiece_model_path=sp_path,
) )
...@@ -361,7 +364,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase): ...@@ -361,7 +364,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
model = task.build_model() model = task.build_model()
params = serving_modules.Translation.Params( params = serving_modules.Translation.Params(
sentencepiece_model_path=sp_path) sentencepiece_model_path=sp_path, batch_size=batch_size)
export_module = serving_modules.Translation(params=params, model=model) export_module = serving_modules.Translation(params=params, model=model)
functions = export_module.get_inference_signatures({ functions = export_module.get_inference_signatures({
"serve_text": "serving_default" "serve_text": "serving_default"
...@@ -371,6 +374,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase): ...@@ -371,6 +374,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(outputs.dtype, tf.string) self.assertEqual(outputs.dtype, tf.string)
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
tmp_dir = os.path.join(tmp_dir, "padded_decode", str(padded_decode))
export_base_dir = os.path.join(tmp_dir, "export") export_base_dir = os.path.join(tmp_dir, "export")
ckpt_dir = os.path.join(tmp_dir, "ckpt") ckpt_dir = os.path.join(tmp_dir, "ckpt")
ckpt_path = tf.train.Checkpoint(model=model).save(ckpt_dir) ckpt_path = tf.train.Checkpoint(model=model).save(ckpt_dir)
......
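Read together with the test above, here is a hedged sketch of how a caller is expected to use the new field. `model` and `sp_path` are assumed to come from the surrounding code (a seq2seq Keras model built with padded_decode=True and a sentencepiece model path, as in the test); the batch size of 8 is purely illustrative.

params = serving_modules.Translation.Params(
    sentencepiece_model_path=sp_path,  # path to a sentencepiece model (assumed)
    batch_size=8)  # must be set when padded_decode is True / serving on TPUs
export_module = serving_modules.Translation(params=params, model=model)
signatures = export_module.get_inference_signatures(
    {"serve_text": "serving_default"})
# "serving_default" is now traced with a fixed-shape input,
# tf.TensorSpec(shape=[8], dtype=tf.string, name="text").

Leaving batch_size as None keeps the previous behavior: shape=[self._batch_size] evaluates to shape=[None], so the exported signature still accepts a dynamic batch dimension.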