# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py

import math

import torch
import torch.nn as nn

try:
    import einops
    from einops.layers.torch import Rearrange
except ImportError:
    print("Einops is not installed")

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
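
# A minimal shape sketch (illustration, not part of the original module): for
# dim=32, a batch of timesteps of shape [batch] maps to [batch, 32], with the
# first 16 features being sines and the last 16 cosines.
#
#   pos_emb = SinusoidalPosEmb(32)
#   emb = pos_emb(torch.arange(4))   # [4, 32]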


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)
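
# Shape sketch (illustration only): both layers preserve the channel count and
# change only the horizon; with the strides/paddings above, Downsample1d halves
# the horizon and Upsample1d doubles it.
#
#   x = torch.randn(1, 8, 16)    # [batch, dim, horizon]
#   Downsample1d(8)(x).shape     # torch.Size([1, 8, 8])
#   Upsample1d(8)(x).shape       # torch.Size([1, 8, 32])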


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            # add and remove a dummy spatial axis around GroupNorm (kept from the original diffuser code)
            Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)
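
# Shape sketch (illustration only): padding=kernel_size // 2 keeps the horizon
# fixed, so only the channel count changes.
#
#   block = Conv1dBlock(inp_channels=4, out_channels=8, kernel_size=5)
#   block(torch.randn(1, 4, 16)).shape   # torch.Size([1, 8, 16])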


class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(inp_channels, out_channels, kernel_size),
                Conv1dBlock(out_channels, out_channels, kernel_size),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            # broadcast the time embedding over the horizon: [batch, t] -> [batch, t, 1]
            Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]
        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)
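
# Shape sketch (illustration only): the time embedding is projected to
# out_channels, broadcast over the horizon, and added after the first conv
# block; a 1x1 conv matches channels on the residual path when needed.
#
#   block = ResidualTemporalBlock(inp_channels=4, out_channels=8, embed_dim=32, horizon=16)
#   block(torch.randn(1, 4, 16), torch.randn(1, 32)).shape   # torch.Size([1, 8, 16])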


class TemporalUNet(ModelMixin, ConfigMixin):
    def __init__(
        self,
        training_horizon,
        transition_dim,
        cond_dim,
        predict_epsilon=False,
        clip_denoised=True,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        self.transition_dim = transition_dim
        self.cond_dim = cond_dim
        self.predict_epsilon = predict_epsilon
        self.clip_denoised = clip_denoised

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        # print(f'[ models/temporal ] Channel dimensions: {in_out}')

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        Upsample1d(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, cond, time):
        """
        x : [ batch x horizon x transition ]
        """

        x = einops.rearrange(x, "b h t -> b t h")

        t = self.time_mlp(time)
        h = []

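        # down path: two residual blocks per resolution, stash a skip connection, then halve the horizon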
        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

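        # bottleneck at the coarsest horizon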
        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

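        # up path: concatenate the matching skip along channels, then double the horizon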
        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, "b t h -> b h t")
        return x
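
# Minimal usage sketch (illustration only; the dims below are arbitrary).
# training_horizon should be divisible by 2 ** (len(dim_mults) - 1) so the
# skip connections line up after every Downsample1d/Upsample1d round trip.
#
#   model = TemporalUNet(training_horizon=32, transition_dim=14, cond_dim=3)
#   x = torch.randn(2, 32, 14)         # [batch, horizon, transition]
#   t = torch.tensor([10, 20])         # diffusion timesteps, [batch]
#   out = model(x, cond=None, time=t)  # [2, 32, 14]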


class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            # project to time_dim (not dim) so torch.cat([x, t]) in final_block lines up
            nn.Linear(dim * 4, time_dim),
        )

        self.blocks = nn.ModuleList([])

        for dim_in, dim_out in in_out:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        Downsample1d(dim_out),
                    ]
                )
            )

            horizon = horizon // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        """
        x : [ batch x horizon x transition ]
        """

        x = einops.rearrange(x, "b h t -> b t h")

        t = self.time_mlp(time)

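        # halve the horizon at every resolution, then flatten for the value head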
        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        x = x.view(len(x), -1)
        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
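

# Minimal usage sketch (illustration only; the dims below are arbitrary).
# With dim_mults=(1, 2, 4, 8) the horizon is halved four times, so horizon=16
# collapses to 1 before the flatten into the value head.
#
#   value_fn = TemporalValue(horizon=16, transition_dim=14, cond_dim=3)
#   x = torch.randn(2, 16, 14)            # [batch, horizon, transition]
#   v = value_fn(x, cond=None, time=torch.tensor([10, 20]))   # [2, 1]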