"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "457abdf2cf31956a15df7233187b0b358307c7d1"
Unverified Commit 5d2d2398 authored by Thanh Le's avatar Thanh Le Committed by GitHub
Browse files

Fix inconsistent random transform in instruct pix2pix (#10698)

* Update train_instruct_pix2pix.py

Fix inconsistent random transform in instruct_pix2pix

* Update train_instruct_pix2pix_sdxl.py
parent 1ae9b059
...@@ -695,7 +695,7 @@ def main(): ...@@ -695,7 +695,7 @@ def main():
) )
# We need to ensure that the original and the edited images undergo the same # We need to ensure that the original and the edited images undergo the same
# augmentation transforms. # augmentation transforms.
images = np.concatenate([original_images, edited_images]) images = np.stack([original_images, edited_images])
images = torch.tensor(images) images = torch.tensor(images)
images = 2 * (images / 255) - 1 images = 2 * (images / 255) - 1
return train_transforms(images) return train_transforms(images)
...@@ -706,7 +706,7 @@ def main(): ...@@ -706,7 +706,7 @@ def main():
# Since the original and edited images were concatenated before # Since the original and edited images were concatenated before
# applying the transformations, we need to separate them and reshape # applying the transformations, we need to separate them and reshape
# them accordingly. # them accordingly.
original_images, edited_images = preprocessed_images.chunk(2) original_images, edited_images = preprocessed_images
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution) original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution) edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
......
...@@ -766,7 +766,7 @@ def main(): ...@@ -766,7 +766,7 @@ def main():
) )
# We need to ensure that the original and the edited images undergo the same # We need to ensure that the original and the edited images undergo the same
# augmentation transforms. # augmentation transforms.
images = np.concatenate([original_images, edited_images]) images = np.stack([original_images, edited_images])
images = torch.tensor(images) images = torch.tensor(images)
images = 2 * (images / 255) - 1 images = 2 * (images / 255) - 1
return train_transforms(images) return train_transforms(images)
...@@ -906,7 +906,7 @@ def main(): ...@@ -906,7 +906,7 @@ def main():
# Since the original and edited images were concatenated before # Since the original and edited images were concatenated before
# applying the transformations, we need to separate them and reshape # applying the transformations, we need to separate them and reshape
# them accordingly. # them accordingly.
original_images, edited_images = preprocessed_images.chunk(2) original_images, edited_images = preprocessed_images
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution) original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution) edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment