tensor.py 497 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from lightx2v.utils.registry_factory import TENSOR_REGISTER


@TENSOR_REGISTER("Default")
class DefaultTensor:
    """Handle for one named entry of a weight dictionary.

    Construction records only the lookup key; the tensor itself is attached
    later by :meth:`load`. Afterwards :meth:`to_cpu` / :meth:`to_cuda` move it,
    each time rebinding ``self.tensor`` to the object returned by the transfer
    call instead of mutating the previously stored object.
    """

    def __init__(self, tensor_name):
        # Key that ``load`` looks up in the weight mapping.
        self.tensor_name = tensor_name

    def load(self, weight_dict):
        """Attach ``weight_dict[self.tensor_name]`` as ``self.tensor``.

        The value is stored as-is (plain assignment, no copy or conversion).
        A missing key propagates the mapping's own lookup error — ``KeyError``
        when ``weight_dict`` is a regular ``dict`` (presumably the usual case;
        confirm against callers).
        """
        key = self.tensor_name
        self.tensor = weight_dict[key]

    def to_cpu(self, non_blocking=False):
        """Move the attached tensor to host memory via ``.to("cpu", ...)``.

        Must run after :meth:`load`: ``self.tensor`` is only created there, so
        calling this first raises ``AttributeError``. ``non_blocking`` is
        forwarded unchanged to the transfer call.
        """
        source = self.tensor
        moved = source.to("cpu", non_blocking=non_blocking)
        self.tensor = moved

    def to_cuda(self, non_blocking=False):
        """Move the attached tensor to the GPU via ``.cuda(...)``.

        Assumes the stored value is a ``torch.Tensor`` (not enforced here); for
        torch, ``.cuda()`` with no device argument targets the current CUDA
        device. Like :meth:`to_cpu`, this requires :meth:`load` to have run and
        forwards ``non_blocking`` unchanged.
        """
        source = self.tensor
        moved = source.cuda(non_blocking=non_blocking)
        self.tensor = moved