# local_data_manager.py: local-filesystem implementation of the deploy data manager.
import asyncio
import os

from loguru import logger

from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.data_manager import BaseDataManager


class LocalDataManager(BaseDataManager):
    """Data manager that stores blobs as plain files on the local filesystem.

    Every byte-level method resolves its target path with
    ``self.fmt_path(self.local_dir, filename, abs_path)`` (inherited from
    ``BaseDataManager``) and is wrapped by ``class_try_catch_async``; how that
    decorator reports or suppresses errors is defined elsewhere.

    NOTE(review): the file I/O below is synchronous, so each call blocks the
    event loop for the full duration of the read/write/delete even though the
    methods are ``async``. Consider offloading large transfers to a worker
    thread if this becomes a bottleneck.
    """

    def __init__(self, local_dir, template_dir):
        """Create a manager rooted at ``local_dir``.

        Args:
            local_dir: Directory that relative filenames are stored under.
                It is created (including parents) if it does not exist.
            template_dir: Optional directory that must contain ``images``,
                ``audios``, ``videos`` and ``tasks`` sub-directories. When
                falsy, template setup is skipped and the ``template_*_dir``
                attributes are not set.

        Raises:
            AssertionError: if ``template_dir`` is given but one of its
                required sub-directories is missing.
            FileExistsError: if ``local_dir`` exists but is not a directory.
        """
        super().__init__()
        self.local_dir = local_dir
        self.name = "local"
        # exist_ok=True avoids the race between a separate exists() check and
        # makedirs() when several processes start up concurrently.
        os.makedirs(self.local_dir, exist_ok=True)

        if template_dir:
            self.template_images_dir = os.path.join(template_dir, "images")
            self.template_audios_dir = os.path.join(template_dir, "audios")
            self.template_videos_dir = os.path.join(template_dir, "videos")
            self.template_tasks_dir = os.path.join(template_dir, "tasks")
            # NOTE(review): asserts are stripped under `python -O`; kept as
            # asserts so the raised exception type is unchanged for callers.
            for required_dir in (
                self.template_images_dir,
                self.template_audios_dir,
                self.template_videos_dir,
                self.template_tasks_dir,
            ):
                assert os.path.exists(required_dir), f"{required_dir} not exists!"

    @class_try_catch_async
    async def save_bytes(self, bytes_data, filename, abs_path=None):
        """Write ``bytes_data`` to the resolved path, overwriting any existing file.

        Returns:
            True once the data has been written and the file closed.
        """
        out_path = self.fmt_path(self.local_dir, filename, abs_path)
        with open(out_path, "wb") as fout:
            fout.write(bytes_data)
        return True

    @class_try_catch_async
    async def load_bytes(self, filename, abs_path=None):
        """Return the full binary contents of the resolved file."""
        inp_path = self.fmt_path(self.local_dir, filename, abs_path)
        with open(inp_path, "rb") as fin:
            return fin.read()

    @class_try_catch_async
    async def delete_bytes(self, filename, abs_path=None):
        """Delete the resolved file and log the deletion.

        Returns:
            True after the file has been removed. ``os.remove`` raises if the
            file is missing; what the caller then sees depends on
            ``class_try_catch_async``.
        """
        inp_path = self.fmt_path(self.local_dir, filename, abs_path)
        os.remove(inp_path)
        logger.info(f"deleted local file {filename}")
        return True

    @class_try_catch_async
    async def file_exists(self, filename, abs_path=None):
        """Return whether the resolved path exists (file or directory)."""
        path = self.fmt_path(self.local_dir, filename, abs_path)
        return os.path.exists(path)

    @class_try_catch_async
    async def list_files(self, base_dir=None):
        """Return the entry names (not full paths) in a directory.

        Args:
            base_dir: Directory to list, used as-is (not passed through
                ``fmt_path``). Falls back to ``self.local_dir`` when falsy.
        """
        prefix = base_dir or self.local_dir
        return os.listdir(prefix)


async def test(
    local_dir="/data/nvme1/liuliang1/lightx2v/local_data",
    image_path="/data/nvme1/liuliang1/lightx2v/assets/img_lightx2v.png",
    device=None,
):
    """Manual smoke test for LocalDataManager.

    Saves an image, a tensor and a nested object (mixing lists, dicts, images
    and tensors), prints what is loaded back, then deletes the three files.

    Args:
        local_dir: Directory the manager writes into.
        image_path: Image file used as the test payload.
        device: Torch device for the test tensor and for loading tensors back.
            Defaults to ``"cuda:0"`` when CUDA is available, else ``"cpu"``.
    """
    import torch
    from PIL import Image

    if device is None:
        # Fall back to CPU so the smoke test still runs on machines without a GPU.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    m = LocalDataManager(local_dir, None)
    await m.init()

    img = Image.open(image_path)
    tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device=device)

    await m.save_image(img, "test_img.png")
    print(await m.load_image("test_img.png"))

    await m.save_tensor(tensor, "test_tensor.pt")
    print(await m.load_tensor("test_tensor.pt", device))

    await m.save_object(
        {
            "images": [img, img],
            "tensor": tensor,
            "list": [
                [2, 0, 5, 5],
                {
                    "1": "hello world",
                    "2": "world",
                    "3": img,
                    "t": tensor,
                },
                "0609",
            ],
        },
        "test_object.json",
    )
    print(await m.load_object("test_object.json", device))

    await m.get_delete_func("OBJECT")("test_object.json")
    await m.get_delete_func("TENSOR")("test_tensor.pt")
    await m.get_delete_func("IMAGE")("test_img.png")


if __name__ == "__main__":
    # Script entry point: drive the async smoke test on a fresh event loop.
    smoke_test = test()
    asyncio.run(smoke_test)