Commit 0efac0aa authored by Patrick von Platen's avatar Patrick von Platen
Browse files

remove einops fully

parent d74b804d
...@@ -9,14 +9,6 @@ from ..configuration_utils import ConfigMixin ...@@ -9,14 +9,6 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
# try:
# import einops
# from einops.layers.torch import Rearrange
# except:
# print("Einops is not installed")
# pass
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
...@@ -198,7 +190,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -198,7 +190,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x : [ batch x horizon x transition ] x : [ batch x horizon x transition ]
""" """
# x = einops.rearrange(x, "b h t -> b t h")
x = x.permute(0, 2, 1) x = x.permute(0, 2, 1)
t = self.time_mlp(timesteps) t = self.time_mlp(timesteps)
...@@ -221,7 +212,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -221,7 +212,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
x = self.final_conv(x) x = self.final_conv(x)
# x = einops.rearrange(x, "b t h -> b h t")
x = x.permute(0, 2, 1) x = x.permute(0, 2, 1)
return x return x
...@@ -279,7 +269,6 @@ class TemporalValue(nn.Module): ...@@ -279,7 +269,6 @@ class TemporalValue(nn.Module):
x : [ batch x horizon x transition ] x : [ batch x horizon x transition ]
""" """
# x = einops.rearrange(x, "b h t -> b t h")
x = x.permute(0, 2, 1) x = x.permute(0, 2, 1)
t = self.time_mlp(time) t = self.time_mlp(time)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment