Commit 15adc369 authored by comfyanonymous's avatar comfyanonymous
Browse files

Move beta_schedule to model_config and allow disabling unet creation.

parent 968078b1
......@@ -19,8 +19,9 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config, device=device)
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
self.model_type = model_type
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
......
......@@ -33,6 +33,7 @@ class BASE:
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
beta_schedule = "linear"
@classmethod
def matches(s, unet_config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment