# SPDX-License-Identifier: Apache-2.0

import pickle
from typing import Any

import torch
from msgspec import msgpack

# msgpack extension type code that tags a pickled numpy payload holding a
# torch tensor; custom_enc_hook writes it and custom_ext_hook checks it.
CUSTOM_TYPE_CODE_PICKLE = 1


class PickleEncoder:
    """Serialize and deserialize arbitrary picklable objects with pickle."""

    def encode(self, obj: Any) -> bytes:
        """Return the pickled byte representation of *obj*."""
        return pickle.dumps(obj)

    def decode(self, data: Any) -> Any:
        """Unpickle *data* (any bytes-like object) and return the object.

        SECURITY: ``pickle.loads`` can execute arbitrary code while
        unpickling; only pass data produced by a trusted peer.
        """
        return pickle.loads(data)


class MsgpackEncoder:
    """Msgpack encoder that delegates unsupported types to custom_enc_hook.

    A single msgspec encoder is built once and reused for every call.
    """

    def __init__(self):
        hook = custom_enc_hook
        self.encoder = msgpack.Encoder(enc_hook=hook)

    def encode(self, obj: Any) -> bytes:
        """Serialize *obj* and return the resulting bytes."""
        encoded = self.encoder.encode(obj)
        return encoded

    def encode_into(self, obj: Any, buf: bytearray) -> None:
        """Serialize *obj* into *buf* via the encoder's ``encode_into``."""
        self.encoder.encode_into(obj, buf)


class MsgpackDecoder:
    """Msgpack decoder that hands extension payloads to custom_ext_hook."""

    def __init__(self, t: Any):
        # ``t`` is passed straight through as msgpack.Decoder's first
        # positional argument (the type the payload is decoded into).
        self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)

    def decode(self, obj: Any):
        """Decode the msgpack payload *obj* with the configured decoder."""
        result = self.decoder.decode(obj)
        return result


def custom_enc_hook(obj: Any) -> Any:
    """Encode objects that msgspec cannot serialize natively.

    A ``torch.Tensor`` is converted to a numpy array, pickled, and wrapped
    in a ``msgpack.Ext`` tagged with ``CUSTOM_TYPE_CODE_PICKLE``;
    ``custom_ext_hook`` turns that payload back into a CPU tensor.

    Raises:
        NotImplementedError: for any other type of object.
    """
    if isinstance(obj, torch.Tensor):
        # NOTE(rob): it is fastest to use numpy + pickle
        # when serializing torch tensors.
        # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
        # Tensor.numpy() refuses tensors that require grad or that live on a
        # non-CPU device. The decode side always rebuilds a plain CPU tensor,
        # so detaching and moving to CPU loses nothing the round trip could
        # keep. For a CPU tensor without grad both calls are cheap (no copy).
        # NOTE(review): dtypes numpy lacks (e.g. bfloat16) still fail here.
        array = obj.detach().cpu().numpy()
        return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(array))

    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")


def custom_ext_hook(code: int, data: memoryview) -> Any:
    """Rebuild objects from msgpack extension payloads.

    Only ``CUSTOM_TYPE_CODE_PICKLE`` is understood: its payload is
    unpickled into a numpy array and wrapped with ``torch.from_numpy``.
    Any other code raises ``NotImplementedError``.
    """
    if code != CUSTOM_TYPE_CODE_PICKLE:
        raise NotImplementedError(f"Extension type code {code} is not supported")

    # SECURITY: pickle.loads can execute arbitrary code; this hook must only
    # see payloads from a trusted encoder.
    array = pickle.loads(data)
    return torch.from_numpy(array)