Commit 4f8426b1 authored by Hongkun Yu, committed by A. Unique TensorFlower

Update export base module to have preprocessor/postprocessor. Upstream the design from the vision module.

PiperOrigin-RevId: 409203520
parent cb36903f
@@ -28,13 +28,34 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
   def __init__(self,
                params,
                model: Union[tf.Module, tf.keras.Model],
-               inference_step: Optional[Callable[..., Any]] = None):
+               preprocessor: Optional[Callable[..., Any]] = None,
+               inference_step: Optional[Callable[..., Any]] = None,
+               postprocessor: Optional[Callable[..., Any]] = None):
     """Instantiates an ExportModel.
 
+    Examples:
+
+    `inference_step` must be a function that takes `model` as a kwarg or as
+    its second positional argument.
+
+    ```
+    def _inference_step(inputs, model=None):
+      return model(inputs, training=False)
+
+    module = ExportModule(params, model, inference_step=_inference_step)
+    ```
+
+    `preprocessor` and `postprocessor` can be either functions or
+    `tf.Module`s; how they are applied is determined by the implementation
+    of the `serve()` method.
+
     Args:
       params: A dataclass for parameters to the module.
       model: A model instance which contains weights and forward computation.
-      inference_step: An optional callable to define how the model is called.
+      preprocessor: An optional callable to preprocess the inputs.
+      inference_step: An optional callable to forward-pass the model. If not
+        specified, `functools.partial(self.model.__call__, training=False)`
+        is used.
+      postprocessor: An optional callable to postprocess the model outputs.
     """
     super().__init__(name=None)
     self.model = model

@@ -45,6 +66,8 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
     else:
       self.inference_step = functools.partial(
           self.model.__call__, training=False)
+    self.preprocessor = preprocessor
+    self.postprocessor = postprocessor
 
   @abc.abstractmethod
   def serve(self) -> Mapping[Text, tf.Tensor]:
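To make the new surface concrete, here is a minimal, self-contained sketch of a subclass wired up with all three hooks. It is not part of this commit: `MyExportModule`, the Keras model, and the lambda preprocessor are illustrative assumptions, and the `serve()` body simply mirrors the `TestModule` pattern in the test diff below.

```python
import tensorflow as tf

from official.core import export_base


class MyExportModule(export_base.ExportModule):
  """Illustrative subclass; not part of this commit."""

  @tf.function
  def serve(self, inputs):
    # Optional preprocessor -> inference_step -> optional postprocessor,
    # the same chaining the TestModule below uses.
    x = inputs if self.preprocessor is None else self.preprocessor(
        inputs=inputs)
    x = self.inference_step(x)
    x = self.postprocessor(x) if self.postprocessor else x
    return {'outputs': x}

  def get_inference_signatures(self, function_keys):
    # Map each requested signature name to a concrete function.
    spec = tf.TensorSpec(shape=[None, 4], dtype=tf.uint8)
    return {name: self.serve.get_concrete_function(spec)
            for name in function_keys.values()}


model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
module = MyExportModule(
    params=None,
    model=model,  # inference_step defaults to model.__call__(training=False)
    preprocessor=lambda inputs: tf.cast(inputs, tf.float32) / 255.0,
    postprocessor=tf.nn.softmax)
print(module.serve(tf.zeros([1, 4], tf.uint8)))
```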
@@ -25,7 +25,11 @@ class TestModule(export_base.ExportModule):
   @tf.function
   def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
-    return {'outputs': self.inference_step(inputs)}
+    x = inputs if self.preprocessor is None else self.preprocessor(
+        inputs=inputs)
+    x = self.inference_step(x)
+    x = self.postprocessor(x) if self.postprocessor else x
+    return {'outputs': x}
 
   def get_inference_signatures(
       self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
@@ -83,6 +87,40 @@ class ExportBaseTest(tf.test.TestCase):
     output = imported.signatures['foo'](inputs)
     self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
 
+  def test_processors(self):
+    model = tf.Module()
+    inputs = tf.zeros((), tf.float32)
+
+    def _inference_step(inputs, model):
+      del model
+      return inputs + 1.0
+
+    def _preprocessor(inputs):
+      print(inputs)
+      return inputs + 0.1
+
+    module = TestModule(
+        params=None,
+        model=model,
+        inference_step=_inference_step,
+        preprocessor=_preprocessor)
+    output = module.serve(inputs)
+    self.assertAllClose(output['outputs'].numpy(), 1.1)
+
+    class _PostProcessor(tf.Module):
+
+      def __call__(self, inputs):
+        return inputs + 0.01
+
+    module = TestModule(
+        params=None,
+        model=model,
+        inference_step=_inference_step,
+        preprocessor=_preprocessor,
+        postprocessor=_PostProcessor())
+    output = module.serve(inputs)
+    self.assertAllClose(output['outputs'].numpy(), 1.11)
+
 
 if __name__ == '__main__':
   tf.test.main()
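Building on the sketch above, the end-to-end export path that the earlier test lines hint at (`tf.saved_model.save`, then `imported.signatures['foo'](inputs)`) would look roughly like this; the path and signature key here are assumptions:

```python
# Serialize the module with the signatures produced by the subclass, then
# reload it and call the named signature, as the existing test does with 'foo'.
signatures = module.get_inference_signatures({'serve': 'serving_default'})
tf.saved_model.save(module, '/tmp/my_export', signatures=signatures)

imported = tf.saved_model.load('/tmp/my_export')
outputs = imported.signatures['serving_default'](
    inputs=tf.zeros([1, 4], tf.uint8))
print(outputs['outputs'])
```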
@@ -38,14 +38,16 @@ class ExportModule(export_base.ExportModule):
       model: A tf.keras.Model instance to be exported.
       input_signature: tf.TensorSpec, e.g.
         tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
-      preprocessor: An optional callable function to preprocess the inputs.
-      inference_step: An optional callable function to forward-pass the model.
-      postprocessor: An optional callable function to postprocess the model
-        outputs.
+      preprocessor: An optional callable to preprocess the inputs.
+      inference_step: An optional callable to forward-pass the model.
+      postprocessor: An optional callable to postprocess the model outputs.
     """
-    super().__init__(params, model=model, inference_step=inference_step)
-    self.preprocessor = preprocessor
-    self.postprocessor = postprocessor
+    super().__init__(
+        params,
+        model=model,
+        preprocessor=preprocessor,
+        inference_step=inference_step,
+        postprocessor=postprocessor)
     self.input_signature = input_signature
 
   @tf.function
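The vision-side change above is a plumbing refactor: the two attributes are now assigned by the base `__init__` instead of the subclass, so existing lookups such as `self.preprocessor` resolve exactly as before. A tiny hypothetical check, reusing `MyExportModule` and `model` from the first sketch:

```python
# The preprocessor handed to the constructor is now stored by the base class;
# subclass code that reads self.preprocessor is unaffected.
resize_fn = lambda inputs: tf.image.resize(inputs, [224, 224])
vision_like_module = MyExportModule(
    params=None, model=model, preprocessor=resize_fn)
assert vision_like_module.preprocessor is resize_fn
```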