encoder_decoder.py 1.48 KB
Newer Older
Neelay Shah's avatar
Neelay Shah committed
1
import numpy
2

Neelay Shah's avatar
Neelay Shah committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from triton_distributed.worker import Operator, RemoteInferenceRequest, RemoteOperator


class EncodeDecodeOperator(Operator):
    """Chain the remote "encoder" and "decoder" operators for each request.

    For every incoming request, the request's ``"input"`` tensor is sent to
    the ``"encoder"`` operator. For each encoder response, its ``"output"``
    tensor is sent to the ``"decoder"`` operator, together with the integer
    read from the encoder's ``"input_copies"`` output, passed as the
    ``input_copies`` parameter. Every decoder response's ``"output"`` tensor
    is forwarded to the original request's response sender.
    """

    def __init__(
        self,
        name,
        version,
        triton_core,
        request_plane,
        data_plane,
        parameters,
        repository,
        logger,
    ) -> None:
        """Create handles to the remote encoder and decoder operators.

        Only ``request_plane``, ``data_plane`` and ``logger`` are used here.
        The remaining arguments are accepted but ignored; presumably they
        exist to match the signature the hosting worker constructs operators
        with (TODO confirm against the ``Operator`` base class).
        """
        self._encoder = RemoteOperator("encoder", request_plane, data_plane)
        self._decoder = RemoteOperator("decoder", request_plane, data_plane)
        self._logger = logger

    async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
        """Run each request through the encoder, then the decoder.

        Args:
            requests: Requests to process in order; each must provide an
                ``"input"`` entry in ``request.inputs``.
        """
        self._logger.info("got request!")
        for request in requests:
            encoded_responses = await self._encoder.async_infer(
                inputs={"input": request.inputs["input"]}
            )

            async for encoded_response in encoded_responses:
                # ``.item()`` extracts the single element whatever the
                # tensor's rank; calling ``int()`` directly on an array with
                # ndim > 0 is deprecated since NumPy 1.25.
                input_copies = int(
                    numpy.from_dlpack(
                        encoded_response.outputs["input_copies"]
                    ).item()
                )
                decoded_responses = await self._decoder.async_infer(
                    inputs={"input": encoded_response.outputs["output"]},
                    parameters={"input_copies": input_copies},
                )

                async for decoded_response in decoded_responses:
                    # NOTE(review): every decoded response is sent with
                    # final=True. If the encoder or decoder can yield more
                    # than one response, confirm the response sender accepts
                    # multiple "final" sends for the same request.
                    await request.response_sender().send(
                        final=True,
                        outputs={"output": decoded_response.outputs["output"]},
                    )
                    # Drop our reference before awaiting the next response,
                    # presumably so the response's tensors can be released
                    # sooner.
                    del decoded_response