# model.py: Triton Python-backend model that serves SGLang character generation.
import numpy
import triton_python_backend_utils as pb_utils
from pydantic import BaseModel

import sglang as sgl
from sglang import function, set_default_backend
from sglang.srt.constrained import build_regex_from_object

# SGLang programs run without an explicit backend (such as
# ``character_gen.run_batch`` further down) are sent to this local server.
set_default_backend(
    sgl.RuntimeEndpoint("http://localhost:30000")
)

class Character(BaseModel):
    """Schema of the JSON object that ``character_gen`` asks the model to emit.

    ``build_regex_from_object(Character)`` converts this model into the regex
    that constrains decoding, so the fields declared here (and presumably their
    declaration order) determine the shape of the generated JSON.
    """

    # All fields are declared as plain strings.
    name: str
    eye_color: str
    house: str

@function
def character_gen(s, name):
    """Ask the model to describe the Harry Potter character *name* as JSON.

    The prompt introduces *name*. The completion is stored under the
    ``"json_output"`` key, is capped at 256 tokens, and is constrained by
    the regex that ``build_regex_from_object`` derives from ``Character``.
    """
    intro = (
        name
        + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    )
    s += intro

    # The regex is built after the prompt is appended, matching the original order.
    character_regex = build_regex_from_object(Character)
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)


class TritonPythonModel:
    """Triton Python-backend model that runs ``character_gen`` per input name.

    Each request carries an ``INPUT_TEXT`` string tensor of character names.
    All names of one request are sent through ``character_gen`` as a single
    SGLang batch. The resulting texts are returned in ``OUTPUT_TEXT`` with
    the same shape as the input tensor.
    """

    def initialize(self, args):
        """Model-load hook called by Triton; ``args`` is not used here."""
        print("Initialized.")

    @staticmethod
    def _decode_names(array):
        """Flatten *array* to a list, decoding ``bytes`` elements as UTF-8.

        Elements that are not ``bytes`` are passed through unchanged.
        Flattening lets both ``[N]`` and ``[N, 1]`` shaped inputs yield one
        name per element instead of nested lists.
        """
        return [
            item.decode("utf-8") if isinstance(item, bytes) else item
            for item in numpy.asarray(array).ravel().tolist()
        ]

    def execute(self, requests):
        """Run character generation for a batch of Triton requests.

        Args:
            requests: the batch of ``pb_utils.InferenceRequest`` objects.

        Returns:
            A list containing exactly one ``pb_utils.InferenceResponse`` per
            request, in request order, as the Python backend requires.
        """
        responses = []
        for request in requests:
            tensor_in = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT")
            if tensor_in is None:
                # Answer this request with an empty response and move on.
                # Returning here would hand Triton a bare response instead of
                # a list and drop every remaining request in the batch.
                responses.append(pb_utils.InferenceResponse(output_tensors=[]))
                continue

            try:
                in_array = tensor_in.as_numpy()
                names = self._decode_names(in_array)
                # Skip the SGLang round trip entirely for an empty tensor.
                states = (
                    character_gen.run_batch([{"name": name} for name in names])
                    if names
                    else []
                )
                # NOTE(review): ``state.text()`` appears to be the whole program
                # text (prompt plus completion). If clients only want the JSON,
                # ``state["json_output"]`` may be what was intended. Confirm
                # before changing, since clients may rely on the current output.
                character_strs = [state.text() for state in states]
                # Mirror the input's shape so e.g. [N, 1] inputs yield [N, 1] outputs.
                out_array = numpy.array(character_strs, dtype=object).reshape(
                    in_array.shape
                )
            except Exception as exc:
                # Fail only this request instead of letting one bad request
                # abort the whole batch.
                responses.append(
                    pb_utils.InferenceResponse(
                        output_tensors=[], error=pb_utils.TritonError(str(exc))
                    )
                )
                continue

            tensor_out = pb_utils.Tensor("OUTPUT_TEXT", out_array)
            responses.append(pb_utils.InferenceResponse(output_tensors=[tensor_out]))
        return responses