# SPDX-License-Identifier: Apache-2.0

3
import pickle
4
from types import FunctionType
5
from typing import Any, Optional
6

7
import cloudpickle
8
9
10
import torch
from msgspec import msgpack

11
12
# msgpack extension type codes. Each code tags an Ext payload with the
# serializer that produced it, so the decoding hook can pick the matching
# loader.
CUSTOM_TYPE_TENSOR = 1  # pickled numpy array, rebuilt into a torch.Tensor
CUSTOM_TYPE_PICKLE = 2  # generic pickle payload
CUSTOM_TYPE_CLOUDPICKLE = 3  # cloudpickle payload (used for functions)


class MsgpackEncoder:
    """msgpack encoder wired up with ``custom_enc_hook``.

    Wraps a ``msgpack.Encoder`` whose ``enc_hook`` is ``custom_enc_hook``,
    the module's hook for torch tensors, plain functions and other objects.
    """

    def __init__(self):
        hook = custom_enc_hook
        self.encoder = msgpack.Encoder(enc_hook=hook)

    def encode(self, obj: Any) -> bytes:
        """Serialize ``obj`` and return the encoded bytes."""
        encoded = self.encoder.encode(obj)
        return encoded

    def encode_into(self, obj: Any, buf: bytearray) -> None:
        """Serialize ``obj`` into the caller-supplied ``buf``."""
        inner = self.encoder
        inner.encode_into(obj, buf)


class MsgpackDecoder:
    """Decoder with custom torch tensor serialization.

    Wraps a ``msgpack.Decoder`` whose ``ext_hook`` is ``custom_ext_hook``,
    so extension payloads are turned back into Python objects by that hook.
    """

    def __init__(self, t: Optional[Any] = None):
        """Build the underlying ``msgpack.Decoder``.

        Args:
            t: Optional type forwarded as the decoder's first positional
                argument. When ``None``, no type argument is passed and
                the decoder's own default applies.
        """
        # Only forward ``t`` when it was given, so the Decoder keeps its
        # own default instead of receiving an explicit ``None``.
        args = () if t is None else (t, )
        self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)

    def decode(self, obj: Any):
        """Decode the msgpack payload ``obj`` and return the result."""
        return self.decoder.decode(obj)


def custom_enc_hook(obj: Any) -> Any:
    """Wrap ``obj`` in a msgpack extension carrying a pickled payload.

    The extension code records which serializer was used:

    * ``torch.Tensor``  -> ``CUSTOM_TYPE_TENSOR`` (pickled numpy array)
    * plain function    -> ``CUSTOM_TYPE_CLOUDPICKLE``
    * anything else     -> ``CUSTOM_TYPE_PICKLE``

    Args:
        obj: The object to serialize.

    Returns:
        A ``msgpack.Ext`` holding the type code and the serialized bytes.
    """
    if isinstance(obj, torch.Tensor):
        # NOTE(rob): it is fastest to use numpy + pickle
        # when serializing torch tensors.
        # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
        # NOTE(review): Tensor.numpy() is documented to reject tensors that
        # are not on the CPU or that require grad, so callers presumably
        # have to move/detach such tensors first - confirm.
        return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))

    if isinstance(obj, FunctionType):
        # Functions go through cloudpickle rather than plain pickle
        # (presumably so lambdas and locally defined functions survive).
        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

    # Fallback: every other object is serialized with plain pickle.
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
51
52
53


def custom_ext_hook(code: int, data: memoryview) -> Any:
    """Rebuild a Python object from a msgpack extension payload.

    Args:
        code: Extension type code; one of ``CUSTOM_TYPE_TENSOR``,
            ``CUSTOM_TYPE_PICKLE`` or ``CUSTOM_TYPE_CLOUDPICKLE``.
        data: The serialized payload stored in the extension.

    Returns:
        A ``torch.Tensor`` for ``CUSTOM_TYPE_TENSOR``, otherwise whatever
        object the pickle / cloudpickle payload describes.

    Raises:
        NotImplementedError: If ``code`` is not one of the known codes.
    """
    # SECURITY: every branch below unpickles ``data``. Unpickling can run
    # arbitrary code, so this hook must only ever see payloads produced by
    # a trusted peer.
    if code == CUSTOM_TYPE_TENSOR:
        # The payload is a pickled numpy array; convert it to a tensor.
        return torch.from_numpy(pickle.loads(data))

    if code == CUSTOM_TYPE_PICKLE:
        return pickle.loads(data)

    if code == CUSTOM_TYPE_CLOUDPICKLE:
        return cloudpickle.loads(data)

    raise NotImplementedError(f"Extension type code {code} is not supported")