"docs/vscode:/vscode.git/clone" did not exist on "e6110f68569c7b620306e678c3a3d9eee1a293e2"
unet_rl.py 6.64 KB
Newer Older
1
2
3
4
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResidualTemporalBlock, Upsample


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return get_timestep_embedding(x, self.dim)


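# Illustrative sketch (not part of the original model code): SinusoidalPosEmb
# maps a 1-D batch of timesteps to [batch, dim] features by delegating to
# get_timestep_embedding. Wrapped in a function so nothing runs at import time.
def _sinusoidal_pos_emb_example():
    emb = SinusoidalPosEmb(dim=32)
    timesteps = torch.tensor([0, 10, 500])
    assert emb(timesteps).shape == (3, 32)

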
class RearrangeDim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")


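# Illustrative sketch: the rank-dependent reshaping RearrangeDim performs; it is
# used below to route 1-D sequences through GroupNorm and back.
def _rearrange_dim_example():
    rearrange = RearrangeDim()
    assert rearrange(torch.randn(4, 8)).shape == (4, 8, 1)
    assert rearrange(torch.randn(4, 8, 16)).shape == (4, 8, 1, 16)
    assert rearrange(torch.randn(4, 8, 1, 16)).shape == (4, 8, 16)

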
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            #            Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            #            Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


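# Illustrative sketch: with an odd kernel size, `padding=kernel_size // 2` keeps
# the horizon length unchanged, so Conv1dBlock only remaps channels.
def _conv1d_block_example():
    block = Conv1dBlock(inp_channels=14, out_channels=32, kernel_size=5)
    x = torch.randn(2, 14, 128)  # [ batch x channels x horizon ]
    assert block(x).shape == (2, 32, 128)

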
class TemporalUNet(ModelMixin, ConfigMixin):
    def __init__(
        self,
        training_horizon=128,
        transition_dim=14,
        cond_dim=3,
        predict_epsilon=False,
        clip_denoised=True,
        dim=32,
        dim_mults=(1, 4, 8),
    ):
        super().__init__()

        self.transition_dim = transition_dim
        self.cond_dim = cond_dim
        self.predict_epsilon = predict_epsilon
        self.clip_denoised = clip_denoised

        # channel widths per resolution, e.g. transition_dim=14, dim=32,
        # dim_mults=(1, 4, 8) -> dims = [14, 32, 128, 256]
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, timesteps):
        """
        x : [ batch x horizon x transition ]
        """

        x = x.permute(0, 2, 1)  # -> [ batch x transition x horizon ] for 1-D convolutions

        t = self.time_mlp(timesteps)

        h = []  # skip connections collected on the down path

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = x.permute(0, 2, 1)  # back to [ batch x horizon x transition ]
        return x


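# Illustrative smoke test (assumes the relative diffusers imports above resolve,
# i.e. the module is used inside the package): the UNet maps a noisy trajectory
# batch plus per-sample diffusion timesteps to an output of the same shape.
def _temporal_unet_example():
    model = TemporalUNet(training_horizon=128, transition_dim=14, dim=32, dim_mults=(1, 4, 8))
    trajectories = torch.randn(2, 128, 14)    # [ batch x horizon x transition ]
    timesteps = torch.randint(0, 1000, (2,))  # one diffusion step per sample
    assert model(trajectories, timesteps).shape == trajectories.shape

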
class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim  # note: time_mlp below always outputs `dim` features
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.blocks = nn.ModuleList([])

        for dim_in, dim_out in in_out:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        Downsample(dim_out, use_conv=True, dims=1),  # strided 1-D conv halves the horizon, as in TemporalUNet
                    ]
                )
            )

            horizon = horizon // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        """
        x : [ batch x horizon x transition ]
        """

        x = x.permute(0, 2, 1)  # -> [ batch x transition x horizon ]

        t = self.time_mlp(time)

        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        x = x.view(len(x), -1)
        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
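

# Illustrative sketch: the value network collapses a trajectory batch to one
# scalar per sample, conditioned on the diffusion timestep (cond is unused in
# forward).
def _temporal_value_example():
    model = TemporalValue(horizon=32, transition_dim=14, cond_dim=3)
    trajectories = torch.randn(2, 32, 14)  # [ batch x horizon x transition ]
    timesteps = torch.randint(0, 1000, (2,))
    assert model(trajectories, None, timesteps).shape == (2, 1)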