"examples/git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "267e25a750cb2e44e48206408f96d60b3dabc0b9"
Commit 4f8426b1 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Update export base module to have preprocessor/postprocessor. Upstream the...

Update export base module to have preprocessor/postprocessor. Upstream the design from vision module.

PiperOrigin-RevId: 409203520
parent cb36903f
......@@ -28,13 +28,34 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
def __init__(self,
params,
model: Union[tf.Module, tf.keras.Model],
inference_step: Optional[Callable[..., Any]] = None):
preprocessor: Optional[Callable[..., Any]] = None,
inference_step: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None):
"""Instantiates an ExportModel.
Examples:
`inference_step` must be a function that has `model` as an kwarg or the
second positional argument.
```
def _inference_step(inputs, model=None):
return model(inputs, training=False)
module = ExportModule(params, model, inference_step=_inference_step)
```
`preprocessor` and `postprocessor` could be either functions or `tf.Module`.
The usages of preprocessor and postprocessor are managed by the
implementation of `serve()` method.
Args:
params: A dataclass for parameters to the module.
model: A model instance which contains weights and forward computation.
inference_step: An optional callable to define how the model is called.
preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model. If not
specified, it creates a parital function with `model` as an required
kwarg.
postprocessor: An optional callable to postprocess the model outputs.
"""
super().__init__(name=None)
self.model = model
......@@ -45,6 +66,8 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
else:
self.inference_step = functools.partial(
self.model.__call__, training=False)
self.preprocessor = preprocessor
self.postprocessor = postprocessor
@abc.abstractmethod
def serve(self) -> Mapping[Text, tf.Tensor]:
......
......@@ -25,7 +25,11 @@ class TestModule(export_base.ExportModule):
@tf.function
def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
return {'outputs': self.inference_step(inputs)}
x = inputs if self.preprocessor is None else self.preprocessor(
inputs=inputs)
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return {'outputs': x}
def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
......@@ -83,6 +87,40 @@ class ExportBaseTest(tf.test.TestCase):
output = imported.signatures['foo'](inputs)
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
def test_processors(self):
model = tf.Module()
inputs = tf.zeros((), tf.float32)
def _inference_step(inputs, model):
del model
return inputs + 1.0
def _preprocessor(inputs):
print(inputs)
return inputs + 0.1
module = TestModule(
params=None,
model=model,
inference_step=_inference_step,
preprocessor=_preprocessor)
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.1)
class _PostProcessor(tf.Module):
def __call__(self, inputs):
return inputs + 0.01
module = TestModule(
params=None,
model=model,
inference_step=_inference_step,
preprocessor=_preprocessor,
postprocessor=_PostProcessor())
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.11)
if __name__ == '__main__':
tf.test.main()
......@@ -38,14 +38,16 @@ class ExportModule(export_base.ExportModule):
model: A tf.keras.Model instance to be exported.
input_signature: tf.TensorSpec, e.g.
tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
preprocessor: An optional callable function to preprocess the inputs.
inference_step: An optional callable function to forward-pass the model.
postprocessor: An optional callable function to postprocess the model
outputs.
preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model.
postprocessor: An optional callable to postprocess the model outputs.
"""
super().__init__(params, model=model, inference_step=inference_step)
self.preprocessor = preprocessor
self.postprocessor = postprocessor
super().__init__(
params,
model=model,
preprocessor=preprocessor,
inference_step=inference_step,
postprocessor=postprocessor)
self.input_signature = input_signature
@tf.function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment