"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "6fcc2e715059ee99ca580d19b505800f1d746c67"
Commit 824e4935 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add dtype parameter to VAE object.

parent 32b7e7e7
...@@ -151,7 +151,7 @@ class CLIP: ...@@ -151,7 +151,7 @@ class CLIP:
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
class VAE: class VAE:
def __init__(self, sd=None, device=None, config=None): def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd) sd = diffusers_convert.convert_vae_state_dict(sd)
...@@ -188,7 +188,9 @@ class VAE: ...@@ -188,7 +188,9 @@ class VAE:
device = model_management.vae_device() device = model_management.vae_device()
self.device = device self.device = device
offload_device = model_management.vae_offload_device() offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype() if dtype is None:
dtype = model_management.vae_dtype()
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype) self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device() self.output_device = model_management.intermediate_device()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment