"test/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "31a69c36c3b7292b43984c7b3b9b01603714749f"
Commit bb1f45d6 authored by comfyanonymous

Properly disable weight initialization in clip models.
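Background on the change: transformers runs its own `_init_weights` pass when a model is constructed, and each PyTorch module additionally self-initializes through `reset_parameters()`. Both are wasted work when the weights are immediately overwritten by a checkpoint load. `modeling_utils.no_init_weights()` only suppresses the transformers-side pass, so this commit also wraps model construction in `comfy.ops.use_comfy_ops()`, which temporarily swaps in layers that skip PyTorch-level init. A minimal standalone sketch of the underlying trick (`NoInitLinear` is an illustrative name, not part of the diff); it is the same approach the diff's `Conv2d` takes, a `reset_parameters` that returns `None`:

    import torch

    class NoInitLinear(torch.nn.Linear):
        # Same tensor layout as torch.nn.Linear, but skip the default
        # Kaiming-uniform init: weight/bias stay as torch.empty()
        # allocations, which is fine because a checkpoint is loaded
        # right after construction.
        def reset_parameters(self):
            return None

    layer = NoInitLinear(8, 8)  # allocates parameters without initializing them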

parent 21f04fe6
comfy/clip_vision.py:

@@ -2,12 +2,14 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
 from .utils import load_torch_file, transformers_convert
 import os
 import torch
+import comfy.ops
 
 class ClipVisionModel():
     def __init__(self, json_config):
         config = CLIPVisionConfig.from_json_file(json_config)
-        with modeling_utils.no_init_weights():
-            self.model = CLIPVisionModelWithProjection(config)
+        with comfy.ops.use_comfy_ops():
+            with modeling_utils.no_init_weights():
+                self.model = CLIPVisionModelWithProjection(config)
         self.processor = CLIPImageProcessor(crop_size=224,
                                             do_center_crop=True,
                                             do_convert_rgb=True,
comfy/ops.py:

 import torch
+from contextlib import contextmanager
 
 class Linear(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -19,3 +20,13 @@ class Linear(torch.nn.Module):
 class Conv2d(torch.nn.Conv2d):
     def reset_parameters(self):
         return None
+
+
+@contextmanager
+def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
+    old_torch_nn_linear = torch.nn.Linear
+    torch.nn.Linear = Linear
+    try:
+        yield
+    finally:
+        torch.nn.Linear = old_torch_nn_linear
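`use_comfy_ops()` works by monkey-patching the global `torch.nn.Linear` name, since transformers instantiates its linear layers through that name internally; the `try`/`finally` guarantees the original class is restored even if model construction raises. A hedged usage sketch (assumes `comfy.ops` is importable; not part of the diff):

    import torch
    import comfy.ops

    with comfy.ops.use_comfy_ops():
        layer = torch.nn.Linear(16, 16)  # resolves to comfy.ops.Linear

    assert isinstance(layer, comfy.ops.Linear)       # built from the patched class
    assert torch.nn.Linear is not comfy.ops.Linear   # patch reverted on exit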
comfy/sd1_clip.py:

 import os
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+import comfy.ops
 import torch
 import traceback
 import zipfile
@@ -38,8 +39,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
         config = CLIPTextConfig.from_json_file(textmodel_json_config)
-        with modeling_utils.no_init_weights():
-            self.transformer = CLIPTextModel(config)
+        with comfy.ops.use_comfy_ops():
+            with modeling_utils.no_init_weights():
+                self.transformer = CLIPTextModel(config)
         self.device = device
         self.max_length = max_length
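Both call sites now nest the two context managers the same way: `use_comfy_ops()` silences PyTorch-level init while `no_init_weights()` silences the transformers-level pass. A sketch of the combined pattern (the bare `CLIPTextConfig()` and the state-dict handling are placeholders, not taken from the commit):

    from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils
    import comfy.ops

    config = CLIPTextConfig()  # placeholder; the real code loads a JSON config file
    with comfy.ops.use_comfy_ops():
        with modeling_utils.no_init_weights():
            model = CLIPTextModel(config)  # tensors allocated, never initialized
    # The model must not be used yet: load real weights first,
    # e.g. model.load_state_dict(state_dict).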