Unverified Commit b35bac4d authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Support PyTorch 1.8] Remove inference mode (#707)

parent 688031c5
...@@ -79,7 +79,7 @@ class StableDiffusionSafetyChecker(PreTrainedModel): ...@@ -79,7 +79,7 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
return images, has_nsfw_concepts return images, has_nsfw_concepts
@torch.inference_mode() @torch.no_grad()
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
pooled_output = self.vision_model(clip_input)[1] # pooled_output pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output) image_embeds = self.visual_projection(pooled_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment