You need to sign in or sign up before continuing.
Commit 4f8426b1 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Update export base module to have preprocessor/postprocessor. Upstream the...

Update export base module to have preprocessor/postprocessor. Upstream the design from vision module.

PiperOrigin-RevId: 409203520
parent cb36903f
...@@ -28,13 +28,34 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -28,13 +28,34 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
def __init__(self, def __init__(self,
params, params,
model: Union[tf.Module, tf.keras.Model], model: Union[tf.Module, tf.keras.Model],
inference_step: Optional[Callable[..., Any]] = None): preprocessor: Optional[Callable[..., Any]] = None,
inference_step: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None):
"""Instantiates an ExportModel. """Instantiates an ExportModel.
Examples:
`inference_step` must be a function that has `model` as an kwarg or the
second positional argument.
```
def _inference_step(inputs, model=None):
return model(inputs, training=False)
module = ExportModule(params, model, inference_step=_inference_step)
```
`preprocessor` and `postprocessor` could be either functions or `tf.Module`.
The usages of preprocessor and postprocessor are managed by the
implementation of `serve()` method.
Args: Args:
params: A dataclass for parameters to the module. params: A dataclass for parameters to the module.
model: A model instance which contains weights and forward computation. model: A model instance which contains weights and forward computation.
inference_step: An optional callable to define how the model is called. preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model. If not
specified, it creates a parital function with `model` as an required
kwarg.
postprocessor: An optional callable to postprocess the model outputs.
""" """
super().__init__(name=None) super().__init__(name=None)
self.model = model self.model = model
...@@ -45,6 +66,8 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -45,6 +66,8 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
else: else:
self.inference_step = functools.partial( self.inference_step = functools.partial(
self.model.__call__, training=False) self.model.__call__, training=False)
self.preprocessor = preprocessor
self.postprocessor = postprocessor
@abc.abstractmethod @abc.abstractmethod
def serve(self) -> Mapping[Text, tf.Tensor]: def serve(self) -> Mapping[Text, tf.Tensor]:
......
...@@ -25,7 +25,11 @@ class TestModule(export_base.ExportModule): ...@@ -25,7 +25,11 @@ class TestModule(export_base.ExportModule):
@tf.function @tf.function
def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]: def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
return {'outputs': self.inference_step(inputs)} x = inputs if self.preprocessor is None else self.preprocessor(
inputs=inputs)
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return {'outputs': x}
def get_inference_signatures( def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]: self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
...@@ -83,6 +87,40 @@ class ExportBaseTest(tf.test.TestCase): ...@@ -83,6 +87,40 @@ class ExportBaseTest(tf.test.TestCase):
output = imported.signatures['foo'](inputs) output = imported.signatures['foo'](inputs)
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy()) self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
def test_processors(self):
model = tf.Module()
inputs = tf.zeros((), tf.float32)
def _inference_step(inputs, model):
del model
return inputs + 1.0
def _preprocessor(inputs):
print(inputs)
return inputs + 0.1
module = TestModule(
params=None,
model=model,
inference_step=_inference_step,
preprocessor=_preprocessor)
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.1)
class _PostProcessor(tf.Module):
def __call__(self, inputs):
return inputs + 0.01
module = TestModule(
params=None,
model=model,
inference_step=_inference_step,
preprocessor=_preprocessor,
postprocessor=_PostProcessor())
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.11)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -38,14 +38,16 @@ class ExportModule(export_base.ExportModule): ...@@ -38,14 +38,16 @@ class ExportModule(export_base.ExportModule):
model: A tf.keras.Model instance to be exported. model: A tf.keras.Model instance to be exported.
input_signature: tf.TensorSpec, e.g. input_signature: tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8) tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
preprocessor: An optional callable function to preprocess the inputs. preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable function to forward-pass the model. inference_step: An optional callable to forward-pass the model.
postprocessor: An optional callable function to postprocess the model postprocessor: An optional callable to postprocess the model outputs.
outputs.
""" """
super().__init__(params, model=model, inference_step=inference_step) super().__init__(
self.preprocessor = preprocessor params,
self.postprocessor = postprocessor model=model,
preprocessor=preprocessor,
inference_step=inference_step,
postprocessor=postprocessor)
self.input_signature = input_signature self.input_signature = input_signature
@tf.function @tf.function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment