# custom_process_mixin.py
from collections.abc import Callable

import torch


class CustomProcessMixin:
    """
    Mixin class for all stages in the Omni model.

    Provides ``preprocess`` / ``postprocess`` hooks. The default
    implementations raise ``NotImplementedError``; a stage can either
    override them in a subclass or register per-instance functions via
    ``set_custom_preprocess`` / ``set_custom_postprocess``.
    """

    def set_custom_preprocess(
        self,
        preprocess_fn: Callable[..., tuple[torch.Tensor, torch.Tensor, dict]],
    ) -> None:
        """
        Set a preprocess function for the stage.

        The function is stored as an instance attribute, shadowing the
        ``preprocess`` method for this instance only. Functions stored on an
        instance are not bound, so a call through ``self.preprocess(...)``
        passes the same arguments as the method signature minus ``self``:
        ``preprocess_fn(input_ids, input_embeds, **input_dict)``.

        Args:
            preprocess_fn: The preprocess function to register.

        Raises:
            TypeError: If ``preprocess_fn`` is not callable.
        """
        # Fail fast at registration time; otherwise a bad value would only
        # surface later, when the stage first tries to call it.
        if not callable(preprocess_fn):
            raise TypeError(
                f"preprocess_fn must be callable, got {type(preprocess_fn).__name__}"
            )
        self.preprocess = preprocess_fn

    def set_custom_postprocess(self, postprocess_fn: Callable[..., object]) -> None:
        """
        Set a postprocess function for the stage.

        The function is stored as an instance attribute, shadowing the
        ``postprocess`` method for this instance only. Functions stored on an
        instance are not bound, so a call through ``self.postprocess(...)``
        passes the same arguments as the method signature minus ``self``:
        ``postprocess_fn(model_output, **info_dict)``.

        Args:
            postprocess_fn: The postprocess function to register.

        Raises:
            TypeError: If ``postprocess_fn`` is not callable.
        """
        # Same fail-fast rationale as set_custom_preprocess.
        if not callable(postprocess_fn):
            raise TypeError(
                f"postprocess_fn must be callable, got {type(postprocess_fn).__name__}"
            )
        self.postprocess = postprocess_fn

    def preprocess(
        self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **input_dict: object
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Process the input_ids and input_embeds for the given input_dict.

        Returns the processed input_ids, input_embeds, and the input_dict.
        If the stage does not apply, implementations should return the
        original input_ids, input_embeds, and an empty dict.

        Args:
            input_ids: Input token ids for the stage.
            input_embeds: Input embeddings for the stage.
            **input_dict: Extra stage-specific inputs.

        Raises:
            NotImplementedError: Always, in this default implementation;
                subclasses override it or register one via
                ``set_custom_preprocess``.
        """
        raise NotImplementedError("Preprocess is not implemented for this stage.")

    def postprocess(self, model_output, **info_dict: object):
        """
        Postprocess the model output.

        Returns the postprocessed model output and the save dictionary.

        Args:
            model_output: The model output to postprocess.
            **info_dict: Extra stage-specific information.

        Raises:
            NotImplementedError: Always, in this default implementation;
                subclasses override it or register one via
                ``set_custom_postprocess``.
        """
        raise NotImplementedError("Postprocess is not implemented for this stage.")