Commit 430a8334 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix some potential issues.

parent 782a24fc
...@@ -92,8 +92,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): ...@@ -92,8 +92,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
else: elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
else:
return None
clip = ClipVisionModel(json_config) clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd) m, u = clip.load_sd(sd)
if len(m) > 0: if len(m) > 0:
......
...@@ -434,10 +434,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o ...@@ -434,10 +434,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clip: if output_clip:
w = WeightsLoader() w = WeightsLoader()
clip_target = model_config.clip_target() clip_target = model_config.clip_target()
clip = CLIP(clip_target, embedding_directory=embedding_directory) if clip_target is not None:
w.cond_stage_model = clip.cond_stage_model clip = CLIP(clip_target, embedding_directory=embedding_directory)
sd = model_config.process_clip_state_dict(sd) w.cond_stage_model = clip.cond_stage_model
load_model_weights(w, sd) sd = model_config.process_clip_state_dict(sd)
load_model_weights(w, sd)
left_over = sd.keys() left_over = sd.keys()
if len(left_over) > 0: if len(left_over) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment