"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6a05b274cc503276a4c1ac22a451df9184a9f761"
Unverified Commit ce31f83d authored by Ryan Russell's avatar Ryan Russell Committed by GitHub
Browse files

refactor: pipelines readability improvements (#622)



* refactor: pipelines readability improvements
Signed-off-by: default avatarRyan Russell <git@ryanrussell.org>

* docs: remove todo comment from flax pipeline
Signed-off-by: default avatarRyan Russell <git@ryanrussell.org>
Signed-off-by: default avatarRyan Russell <git@ryanrussell.org>
parent b00382e2
...@@ -34,7 +34,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -34,7 +34,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]): safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
...@@ -149,7 +149,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -149,7 +149,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings]) context = jnp.concatenate([uncond_embeddings, text_embeddings])
# TODO: check it because the shape is different from PyTorch StableDiffusionPipeline
latents_shape = ( latents_shape = (
batch_size, batch_size,
self.unet.in_channels, self.unet.in_channels,
...@@ -206,9 +205,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -206,9 +205,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# image = jnp.asarray(image).transpose(0, 2, 3, 1) # image = jnp.asarray(image).transpose(0, 2, 3, 1)
# run safety checker # run safety checker
# TODO: check when flax safety checker gets merged into main # TODO: check when flax safety checker gets merged into main
# safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
# image, has_nsfw_concept = self.safety_checker( # image, has_nsfw_concept = self.safety_checker(
# images=image, clip_input=safety_cheker_input.pixel_values, params=params["safety_params"] # images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
# ) # )
has_nsfw_concept = False has_nsfw_concept = False
......
...@@ -36,7 +36,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -36,7 +36,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
...@@ -278,8 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -278,8 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker # run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
......
...@@ -48,7 +48,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -48,7 +48,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
...@@ -288,8 +288,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -288,8 +288,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker # run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
......
...@@ -66,7 +66,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -66,7 +66,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
...@@ -328,8 +328,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -328,8 +328,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker # run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
......
...@@ -48,20 +48,20 @@ class StableDiffusionSafetyChecker(PreTrainedModel): ...@@ -48,20 +48,20 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
# at the cost of increasing the possibility of filtering benign images # at the cost of increasing the possibility of filtering benign images
adjustment = 0.0 adjustment = 0.0
for concet_idx in range(len(special_cos_dist[0])): for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concet_idx] concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concet_idx].item() concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concet_idx] > 0: if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]}) result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01 adjustment = 0.01
for concet_idx in range(len(cos_dist[0])): for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concet_idx] concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concet_idx].item() concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concet_idx] > 0: if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concet_idx) result_img["bad_concepts"].append(concept_idx)
result.append(result_img) result.append(result_img)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment