Commit 4b5560cd authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 409275976
parent 4f8426b1
...@@ -28,8 +28,9 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -28,8 +28,9 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
def __init__(self, def __init__(self,
params, params,
model: Union[tf.Module, tf.keras.Model], model: Union[tf.Module, tf.keras.Model],
preprocessor: Optional[Callable[..., Any]] = None,
inference_step: Optional[Callable[..., Any]] = None, inference_step: Optional[Callable[..., Any]] = None,
*,
preprocessor: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None): postprocessor: Optional[Callable[..., Any]] = None):
"""Instantiates an ExportModel. """Instantiates an ExportModel.
...@@ -51,10 +52,10 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -51,10 +52,10 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
Args: Args:
params: A dataclass for parameters to the module. params: A dataclass for parameters to the module.
model: A model instance which contains weights and forward computation. model: A model instance which contains weights and forward computation.
preprocessor: An optional callable to preprocess the inputs.
inference_step: An optional callable to forward-pass the model. If not inference_step: An optional callable to forward-pass the model. If not
specified, it creates a parital function with `model` as an required specified, it creates a parital function with `model` as an required
kwarg. kwarg.
preprocessor: An optional callable to preprocess the inputs.
postprocessor: An optional callable to postprocess the model outputs. postprocessor: An optional callable to postprocess the model outputs.
""" """
super().__init__(name=None) super().__init__(name=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment