tensor.py 915 Bytes
Newer Older
Xinchi Huang's avatar
Xinchi Huang committed
1
import torch
2
3
4
5
6
7
8
9
10
11
from lightx2v.utils.registry_factory import TENSOR_REGISTER


@TENSOR_REGISTER("Default")
class DefaultTensor:
    """Hold one named tensor and move it between host and CUDA memory.

    ``load`` fetches the tensor from a weight dictionary and allocates a
    host-side staging buffer of the same shape and dtype. ``to_cpu`` copies
    the current tensor into that buffer instead of allocating a new CPU
    tensor on every call.
    """

    def __init__(self, tensor_name):
        """Remember the key that identifies this tensor.

        Args:
            tensor_name: Key used to look the tensor up in ``load`` and to
                store it in ``state_dict``.
        """
        self.tensor_name = tensor_name

    def load(self, weight_dict):
        """Fetch the tensor from ``weight_dict`` and allocate the staging buffer.

        Args:
            weight_dict: Container indexed by tensor name; it must contain
                ``self.tensor_name``.
        """
        self.tensor = weight_dict[self.tensor_name]
        try:
            # Page-locked host memory is what lets ``to_cpu(non_blocking=True)``
            # overlap the device-to-host copy with other work.
            self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)
        except RuntimeError:
            # Pinned allocation is unavailable on this host (e.g. no
            # accelerator). Fall back to an ordinary CPU buffer so loading,
            # ``to_cpu`` and ``state_dict`` keep working; copies into it are
            # simply synchronous.
            self.pinned_tensor = torch.empty(self.tensor.shape, dtype=self.tensor.dtype, device="cpu")

    def to_cpu(self, non_blocking=False):
        """Copy the current tensor into the host staging buffer.

        After this call ``self.tensor`` *is* the staging buffer (the same
        object as ``self.pinned_tensor``), so the next ``to_cpu`` overwrites
        the same storage.

        Args:
            non_blocking: Forwarded to ``Tensor.copy_``.
                NOTE(review): per the PyTorch CUDA semantics notes, a
                non-blocking copy into pinned memory may return before the
                data has arrived; callers are expected to synchronize before
                reading the result - confirm at the call sites.
        """
        # ``copy_`` returns the destination buffer, which already lives in
        # host memory, so no further ``.cpu()`` call is needed.
        self.tensor = self.pinned_tensor.copy_(self.tensor, non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        """Replace ``self.tensor`` with a copy on the current CUDA device.

        Args:
            non_blocking: Forwarded to ``Tensor.cuda``.
        """
        self.tensor = self.tensor.cuda(non_blocking=non_blocking)

    def state_dict(self, destination=None):
        """Store a detached CPU copy of the tensor under its name.

        Args:
            destination: Dictionary to fill in place; a new one is created
                when ``None``.

        Returns:
            ``destination`` with ``self.tensor_name`` mapped to a cloned,
            detached CPU tensor, so later ``to_cpu`` calls cannot overwrite
            the stored value.
        """
        if destination is None:
            destination = {}
        destination[self.tensor_name] = self.tensor.cpu().detach().clone()
        return destination