Commit 21f04fe6 authored by comfyanonymous's avatar comfyanonymous
Browse files

Disable default weight values in unet conv2d for faster loading.

parent 9d54066e
...@@ -16,7 +16,7 @@ import numpy as np ...@@ -16,7 +16,7 @@ import numpy as np
from einops import repeat from einops import repeat
from comfy.ldm.util import instantiate_from_config from comfy.ldm.util import instantiate_from_config
import comfy.ops
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear": if schedule == "linear":
...@@ -233,7 +233,7 @@ def conv_nd(dims, *args, **kwargs): ...@@ -233,7 +233,7 @@ def conv_nd(dims, *args, **kwargs):
if dims == 1: if dims == 1:
return nn.Conv1d(*args, **kwargs) return nn.Conv1d(*args, **kwargs)
elif dims == 2: elif dims == 2:
return nn.Conv2d(*args, **kwargs) return comfy.ops.Conv2d(*args, **kwargs)
elif dims == 3: elif dims == 3:
return nn.Conv3d(*args, **kwargs) return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f"unsupported dimensions: {dims}")
...@@ -243,7 +243,7 @@ def linear(*args, **kwargs): ...@@ -243,7 +243,7 @@ def linear(*args, **kwargs):
""" """
Create a linear module. Create a linear module.
""" """
return nn.Linear(*args, **kwargs) return comfy.ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs): def avg_pool_nd(dims, *args, **kwargs):
......
...@@ -15,3 +15,7 @@ class Linear(torch.nn.Module): ...@@ -15,3 +15,7 @@ class Linear(torch.nn.Module):
def forward(self, input): def forward(self, input):
return torch.nn.functional.linear(input, self.weight, self.bias) return torch.nn.functional.linear(input, self.weight, self.bias)
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment