embedding_weight.py
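
"""Embedding weight wrappers for lightx2v.

Defines a template class that loads an embedding table from a state dict,
keeps a pinned-memory host buffer for asynchronous CPU/GPU transfers, and a
default implementation (registered as "Default") that performs the lookup.
"""
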
from abc import ABCMeta

import torch
import torch.nn.functional as F

from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER


class EmbeddingWeightTemplate(metaclass=ABCMeta):
    """Base wrapper around a single embedding weight tensor.

    Handles loading the tensor from a state dict, optional lazy loading, and
    moving it between CPU (via a pinned-memory buffer) and the GPU.
    """

    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.config = {}

    def load(self, weight_dict):
        if not self.lazy_load:
            if self.weight_name is not None:
                self.weight = weight_dict[self.weight_name]
                # Pre-allocate a pinned host buffer so later CPU offloads can be asynchronous.
                self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
                # Drop the entry from the source dict so the checkpoint copy can be freed.
                del weight_dict[self.weight_name]
            else:
                self.weight = None

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pinned_weight"):
            # copy_() fills the pre-allocated pinned buffer and returns it (already on CPU).
            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking)
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        # Move the weight to the current CUDA device.
        self.weight = self.weight.cuda(non_blocking=non_blocking)


@EMBEDDING_WEIGHT_REGISTER("Default")
class EmbeddingWeight(EmbeddingWeightTemplate):
    def __init__(self, weight_name=None, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, lazy_load, lazy_load_file)

    def apply(self, input_indices):
        # Plain embedding lookup: rows of self.weight selected by input_indices.
        return F.embedding(input_indices, self.weight)
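

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's public API).
# Assumes a CUDA device is available and a state dict holding the embedding
# table under the hypothetical key "token_embedding.weight".
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, hidden_dim = 1000, 64
    weight_dict = {"token_embedding.weight": torch.randn(vocab_size, hidden_dim, dtype=torch.float16)}

    emb = EmbeddingWeight(weight_name="token_embedding.weight")
    emb.load(weight_dict)          # takes ownership of the tensor, allocates a pinned buffer
    emb.to_cuda()                  # move the table to the GPU before inference
    ids = torch.randint(0, vocab_size, (2, 8), device="cuda")
    out = emb.apply(ids)           # shape (2, 8, hidden_dim)
    emb.to_cpu(non_blocking=True)  # offload back into the pinned host buffer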