torchrec.py 3.81 KB
Newer Older
1
2
3
from functools import partial

import torch
4
5
6
7
from torchrec.models import deepfm, dlrm
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
8

9
from ..registry import model_zoo
10

11
12
# Leading dimension of the random tensors built by the data generators in this file.
BATCH = 2
# Shared size reused throughout this file: feature widths of the random inputs,
# per-key lengths, embedding dimensions / table sizes and model constructor arguments.
SHAPE = 10
13
14
15
16
17
18


def gen_kt():
    """Build a ``KeyedTensor`` with random values for the keys ``f1`` and ``f2``.

    Every key is declared with a length of ``SHAPE`` and the values come from
    ``torch.rand`` with shape ``(BATCH, 2 * SHAPE)``.

    Returns:
        The newly constructed ``KeyedTensor``.
    """
    feature_keys = ["f1", "f2"]
    per_key_lengths = [SHAPE] * len(feature_keys)
    random_values = torch.rand((BATCH, 2 * SHAPE))
    return KeyedTensor(keys=feature_keys, length_per_key=per_key_lengths, values=random_values)

19

20
# KeyedJaggedTensor
def gen_kjt():
    """Build a fixed ``KeyedJaggedTensor`` for the keys ``f1`` and ``f2``.

    The tensor is created through ``KeyedJaggedTensor.from_offsets_sync`` from the
    eight integer values ``1..8`` and the offsets ``[0, 2, 4, 6, 8]``.

    Returns:
        The newly constructed ``KeyedJaggedTensor``.
    """
    # NOTE(review): the five offsets presumably describe two bags per key
    # (matching BATCH = 2), each holding two values -- confirm against torchrec.
    values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
    offsets = torch.tensor([0, 2, 4, 6, 8])
    return KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"], values=values, offsets=offsets)

27

28
def data_gen_fn():
    """Generate inputs for models that take a single dense ``features`` tensor.

    Returns:
        A dict with one entry, ``features``, holding a ``torch.rand`` tensor of
        shape ``(BATCH, SHAPE)``.
    """
    return dict(features=torch.rand((BATCH, SHAPE)))
29
30


31
32
33
34
35
36
37
38
39
def interaction_arch_data_gen_fn():
    """Generate inputs made of random dense features plus ``KeyedTensor`` sparse features.

    Returns:
        A dict with ``dense_features`` (a ``torch.rand`` tensor of shape
        ``(BATCH, SHAPE)``) and ``sparse_features`` (the result of ``gen_kt()``).
    """
    # The sparse tensor is created first so random numbers are drawn in the same order as before.
    sparse = gen_kt()
    dense = torch.rand((BATCH, SHAPE))
    return {"dense_features": dense, "sparse_features": sparse}


def simple_dfm_data_gen_fn():
    """Generate inputs made of random dense features plus ``KeyedJaggedTensor`` sparse features.

    Returns:
        A dict with ``dense_features`` (a ``torch.rand`` tensor of shape
        ``(BATCH, SHAPE)``) and ``sparse_features`` (the result of ``gen_kjt()``).
    """
    sparse = gen_kjt()
    dense = torch.rand((BATCH, SHAPE))
    return {"dense_features": dense, "sparse_features": sparse}

40

41
42
43
44
45
46
47
48
49
50
51
52
53
def sparse_arch_data_gen_fn():
    """Generate inputs consisting only of sparse features.

    Returns:
        A dict with one entry, ``features``, holding the result of ``gen_kjt()``.
    """
    return {"features": gen_kjt()}


def output_transform_fn(x):
    """Convert a model output into a plain ``dict``.

    Args:
        x: The value returned by a model.

    Returns:
        For a ``KeyedTensor``, a dict mapping every key in ``x.keys()`` to
        ``x[key]``; for anything else, ``{"output": x}``.
    """
    if not isinstance(x, KeyedTensor):
        return {"output": x}
    return {key: x[key] for key in x.keys()}
54
55


56
57
58
59
def get_ebc():
    """Build a CPU ``EmbeddingBagCollection`` with one table per sparse feature.

    Table ``t1`` serves feature ``f1`` and table ``t2`` serves feature ``f2``;
    both use ``SHAPE`` for ``embedding_dim`` and for ``num_embeddings``.

    Returns:
        A new ``EmbeddingBagCollection`` placed on ``torch.device("cpu")``.
    """
    # EmbeddingBagCollection
    eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
    eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
    return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device("cpu"))
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80


def sparse_arch_model_fn():
    """Create a ``deepfm.SparseArch`` around a fresh collection from ``get_ebc()``."""
    return deepfm.SparseArch(get_ebc())


def simple_deep_fmnn_model_fn():
    """Create a ``deepfm.SimpleDeepFMNN`` backed by a fresh collection from ``get_ebc()``.

    ``SHAPE`` is passed for the first positional argument and for the two
    positional arguments that follow the embedding bag collection.
    """
    return deepfm.SimpleDeepFMNN(SHAPE, get_ebc(), SHAPE, SHAPE)


def dlrm_model_fn():
    """Create a ``dlrm.DLRM`` backed by a fresh collection from ``get_ebc()``.

    The remaining positional arguments are ``SHAPE``, ``[SHAPE, SHAPE]`` and
    ``[5, 1]``, in that order.
    """
    return dlrm.DLRM(get_ebc(), SHAPE, [SHAPE, SHAPE], [5, 1])


def dlrm_sparsearch_model_fn():
    """Create a ``dlrm.SparseArch`` around a fresh collection from ``get_ebc()``."""
    return dlrm.SparseArch(get_ebc())
81
82


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Registry entries as (name, model_fn, data_gen_fn) triples, listed in registration
# order. Every entry is registered with the same ``output_transform_fn``.
_TORCHREC_MODELS = (
    ("deepfm_densearch", partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE), data_gen_fn),
    (
        "deepfm_interactionarch",
        partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
        interaction_arch_data_gen_fn,
    ),
    ("deepfm_overarch", partial(deepfm.OverArch, SHAPE), data_gen_fn),
    ("deepfm_simpledeepfmnn", simple_deep_fmnn_model_fn, simple_dfm_data_gen_fn),
    ("deepfm_sparsearch", sparse_arch_model_fn, sparse_arch_data_gen_fn),
    ("dlrm", dlrm_model_fn, simple_dfm_data_gen_fn),
    ("dlrm_densearch", partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]), data_gen_fn),
    ("dlrm_interactionarch", partial(dlrm.InteractionArch, 2), interaction_arch_data_gen_fn),
    ("dlrm_overarch", partial(dlrm.OverArch, SHAPE, [5, 1]), data_gen_fn),
    ("dlrm_sparsearch", dlrm_sparsearch_model_fn, sparse_arch_data_gen_fn),
)

for _entry_name, _entry_model_fn, _entry_data_gen_fn in _TORCHREC_MODELS:
    model_zoo.register(
        name=_entry_name,
        model_fn=_entry_model_fn,
        data_gen_fn=_entry_data_gen_fn,
        output_transform_fn=output_transform_fn,
    )