Unverified Commit effe9d66 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[FlaxStableDiffusionPipeline] fix bug when nsfw is detected (#832)

fix nsfw bug
parent 0679d090
...@@ -291,7 +291,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -291,7 +291,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# block images # block images
if any(has_nsfw_concept): if any(has_nsfw_concept):
for i, is_nsfw in enumerate(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept):
images[i] = np.asarray(images_uint8_casted[i]) if is_nsfw:
images[i] = np.asarray(images_uint8_casted[i])
images = images.reshape(num_devices, batch_size, height, width, 3) images = images.reshape(num_devices, batch_size, height, width, 3)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment