Unverified Commit f67639b0 authored by jiqing-feng's avatar jiqing-feng Committed by GitHub
Browse files

add post init for safty checker (#12794)



* add post init for safty checker
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>

* check transformers version before post init
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>

* Apply style fixes

---------
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 5a743197
......@@ -17,7 +17,7 @@ import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from ...utils import logging
from ...utils import is_transformers_version, logging
logger = logging.get_logger(__name__)
......@@ -46,6 +46,9 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
# Model requires post_init after transformers v4.57.3
if is_transformers_version(">", "4.57.3"):
self.post_init()
@torch.no_grad()
def forward(self, clip_input, images):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment