"megatron/inference/text_generation/generation.py" did not exist on "b7b2d6a91233ed8e6cd6492fd659dc481b5636b1"
pb.py 2.05 KB
Newer Older
sunzhq2's avatar
init  
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Any, Generator, Iterable, List

from llm_perf.server import server_pb2, server_pb2_grpc


def deserialize_value(value: server_pb2.Value) -> Any:
    """Convert a ``server_pb2.Value`` message into a plain Python object.

    The populated member of the message's ``kind`` oneof decides the result:

    * ``float_`` / ``int64_`` / ``bytes_`` / ``string_`` -> the field value
      as-is.
    * ``float_list`` / ``int64_list`` / ``bytes_list`` / ``string_list`` ->
      a new ``list`` copied from the wrapper's ``values`` field.
    * ``struct_`` -> a ``dict`` whose values are deserialized recursively.

    Args:
        value: The message to decode.

    Returns:
        The decoded Python object.

    Raises:
        TypeError: If the ``kind`` oneof is unset or names a field this
            function does not handle.
    """
    kind = value.WhichOneof("kind")

    # Scalar kinds: the field name reported by the oneof is also the
    # attribute that holds the payload.
    if kind in ("float_", "int64_", "bytes_", "string_"):
        return getattr(value, kind)

    # List kinds: copy the wrapper's ``values`` into a plain list so the
    # caller does not hold on to the message's own container.
    if kind in ("float_list", "int64_list", "bytes_list", "string_list"):
        return list(getattr(value, kind).values)

    if kind == "struct_":
        return {k: deserialize_value(v) for k, v in value.struct_.fields.items()}

    # ``type(value)`` is the same for every input, so include the kind that
    # was actually reported to make the failure diagnosable.
    raise TypeError(
        f"Invalid type {type(value)}: unsupported or unset kind {kind!r}"
    )


def serialize_value(value: Any) -> server_pb2.Value:
    """Wrap a plain Python object in a ``server_pb2.Value`` message.

    Supported inputs and the field they populate:

    * ``float`` -> ``float_``
    * ``int`` -> ``int64_`` (``bool`` is a subclass of ``int`` and therefore
      takes this branch as well)
    * ``bytes`` -> ``bytes_``
    * ``str`` -> ``string_``
    * non-empty ``list`` -> ``float_list`` / ``int64_list`` / ``bytes_list`` /
      ``string_list``, selected from the type of the FIRST element only
    * ``dict`` -> ``struct_``, with every value serialized recursively

    Args:
        value: The object to encode.

    Returns:
        A ``server_pb2.Value`` with exactly one field populated.

    Raises:
        TypeError: If ``value`` (or the first element of a list) has an
            unsupported type, or if ``value`` is an empty list, whose element
            type cannot be inferred.
    """
    if isinstance(value, float):
        return server_pb2.Value(float_=value)
    if isinstance(value, int):
        return server_pb2.Value(int64_=value)
    if isinstance(value, bytes):
        return server_pb2.Value(bytes_=value)
    if isinstance(value, str):
        return server_pb2.Value(string_=value)

    if isinstance(value, list):
        # The list wrapper is chosen from the first element, so an empty list
        # gives nothing to go on; fail clearly instead of raising IndexError.
        if not value:
            raise TypeError(
                f"Invalid type {type(value)}: cannot infer the element type "
                "of an empty list"
            )
        first = value[0]
        if isinstance(first, float):
            return server_pb2.Value(float_list=server_pb2.FloatList(values=value))
        if isinstance(first, int):
            return server_pb2.Value(int64_list=server_pb2.Int64List(values=value))
        if isinstance(first, bytes):
            return server_pb2.Value(bytes_list=server_pb2.BytesList(values=value))
        if isinstance(first, str):
            return server_pb2.Value(string_list=server_pb2.StringList(values=value))
        # Previously this case fell through and returned None silently.
        raise TypeError(f"Invalid type {type(value)} with element type {type(first)}")

    if isinstance(value, dict):
        return server_pb2.Value(
            struct_=server_pb2.Struct(
                fields={k: serialize_value(v) for k, v in value.items()}
            )
        )

    raise TypeError(f"Invalid type {type(value)}")