Unverified commit b9eea06e, authored by Kane Wallmann and committed by GitHub
Browse files

Include CLIPTextModel parameters in conversion (#695)

parent 08d4fb6e
......@@ -595,6 +595,22 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
return hf_model
def convert_ldm_clip_checkpoint(checkpoint):
    """Build a ``CLIPTextModel`` loaded with the text-encoder weights in ``checkpoint``.

    Args:
        checkpoint: Mapping from parameter name to parameter value. Entries
            whose key starts with ``"cond_stage_model.transformer."`` are
            treated as text-encoder parameters; every other entry is ignored.

    Returns:
        The model created by
        ``CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")`` after
        ``load_state_dict`` has been called on it with the extracted parameters
        (keys with the prefix stripped).
    """
    # A single prefix, trailing dot included, drives both the filter and the
    # slice. Filtering on the dot-less prefix while slicing by the dotted
    # length would chop the first character off any key that shares the
    # shorter prefix but is not actually a child of ``transformer``.
    prefix = "cond_stage_model.transformer."

    # The pretrained model supplies the architecture; its weights are then
    # overwritten by the ones found in the checkpoint.
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    text_model_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint.items()
        if key.startswith(prefix)
    }

    # NOTE(review): load_state_dict is called with its default arguments;
    # presumably that means strict key matching - confirm against PyTorch docs.
    text_model.load_state_dict(text_model_dict)

    return text_model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......@@ -668,7 +684,7 @@ if __name__ == "__main__":
# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenCLIPEmbedder":
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment