Commit b72f4975 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[seq2seq] hardcode batch size

PiperOrigin-RevId: 430563213
parent 0bcb7aa0
......@@ -417,6 +417,8 @@ class Translation(export_base.ExportModule):
@dataclasses.dataclass
class Params(base_config.Config):
sentencepiece_model_path: str = ""
# Needs to be specified if padded_decode is True/on TPUs.
batch_size: Optional[int] = None
def __init__(self, params, model: tf.keras.Model, inference_step=None):
super().__init__(params, model, inference_step)
......@@ -431,6 +433,7 @@ class Translation(export_base.ExportModule):
"Please make sure the tokenizer generates a single token for an "
"empty string.")
self._eos_id = empty_str_tokenized.item()
self._batch_size = params.batch_size
@tf.function
def serve(self, inputs) -> Dict[str, tf.Tensor]:
......@@ -452,5 +455,6 @@ class Translation(export_base.ExportModule):
(self.__class__, func_key, valid_keys))
if func_key == "serve_text":
signatures[signature_key] = self.serve_text.get_concrete_function(
tf.TensorSpec(shape=[None], dtype=tf.string, name="text"))
tf.TensorSpec(shape=[self._batch_size],
dtype=tf.string, name="text"))
return signatures
......@@ -344,7 +344,10 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(ValueError):
_ = export_module.get_inference_signatures({"foo": None})
def test_translation(self):
@parameterized.parameters(
(False, None),
(True, 2))
def test_translation(self, padded_decode, batch_size):
sp_path = _make_sentencepeice(self.get_temp_dir())
encdecoder = translation.EncDecoder(
num_attention_heads=4, intermediate_size=256)
......@@ -353,7 +356,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
encoder=encdecoder,
decoder=encdecoder,
embedding_width=256,
padded_decode=False,
padded_decode=padded_decode,
decode_max_length=100),
sentencepiece_model_path=sp_path,
)
......@@ -361,7 +364,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
model = task.build_model()
params = serving_modules.Translation.Params(
sentencepiece_model_path=sp_path)
sentencepiece_model_path=sp_path, batch_size=batch_size)
export_module = serving_modules.Translation(params=params, model=model)
functions = export_module.get_inference_signatures({
"serve_text": "serving_default"
......@@ -371,6 +374,7 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(outputs.dtype, tf.string)
tmp_dir = self.get_temp_dir()
tmp_dir = os.path.join(tmp_dir, "padded_decode", str(padded_decode))
export_base_dir = os.path.join(tmp_dir, "export")
ckpt_dir = os.path.join(tmp_dir, "ckpt")
ckpt_path = tf.train.Checkpoint(model=model).save(ckpt_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment