"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "4a28a2f093de06d8a5735e2f435d35a2d7406391"
Unverified Commit 58f37715 authored by Joseph Turian's avatar Joseph Turian Committed by GitHub
Browse files

Add optional precision-preserving preprocessing for...


Add optional precision-preserving preprocessing for examples/unconditional_image_generation/train_unconditional.py (#12596)

* Add optional precision-preserving preprocessing

* Document decoder caveat for precision flag

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 6198f8a1
...@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways: ...@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir` - you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument. - or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
Below, we explain both in more detail. Below, we explain both in more detail.
#### Provide the dataset as a folder #### Provide the dataset as a folder
......
...@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): ...@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
return res.expand(broadcast_shape) return res.expand(broadcast_shape)
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
"""
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
"""
if tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
channels = tensor.shape[0]
if channels == 3:
return tensor
if channels == 1:
return tensor.repeat(3, 1, 1)
if channels == 2:
return torch.cat([tensor, tensor[:1]], dim=0)
if channels > 3:
return tensor[:3]
raise ValueError(f"Unsupported number of channels: {channels}")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument( parser.add_argument(
...@@ -260,6 +278,11 @@ def parse_args(): ...@@ -260,6 +278,11 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
) )
parser.add_argument(
"--preserve_input_precision",
action="store_true",
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
)
args = parser.parse_args() args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
...@@ -453,19 +476,41 @@ def main(args): ...@@ -453,19 +476,41 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation. # Preprocessing the datasets and DataLoaders creation.
spatial_augmentations = [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
]
augmentations = transforms.Compose( augmentations = transforms.Compose(
[ spatial_augmentations
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + [
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), transforms.Normalize([0.5], [0.5]),
] ]
) )
precision_augmentations = transforms.Compose(
[
transforms.PILToTensor(),
transforms.Lambda(_ensure_three_channels),
transforms.ConvertImageDtype(torch.float32),
]
+ spatial_augmentations
+ [transforms.Normalize([0.5], [0.5])]
)
def transform_images(examples): def transform_images(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]] processed = []
return {"input": images} for image in examples["image"]:
if not args.preserve_input_precision:
processed.append(augmentations(image.convert("RGB")))
else:
precise_image = image
if precise_image.mode == "P":
precise_image = precise_image.convert("RGB")
processed.append(precision_augmentations(precise_image))
return {"input": processed}
logger.info(f"Dataset size: {len(dataset)}") logger.info(f"Dataset size: {len(dataset)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment