Commit 824e4935 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add dtype parameter to VAE object.

parent 32b7e7e7
......@@ -151,7 +151,7 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None):
def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
......@@ -188,7 +188,9 @@ class VAE:
device = model_management.vae_device()
self.device = device
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
if dtype is None:
dtype = model_management.vae_dtype()
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment