Unverified Commit e2d6d5ce authored by Michal Jamroz's avatar Michal Jamroz Committed by GitHub
Browse files

Normalize only if needed (#26049)



* Normalize only if needed

* Update examples/pytorch/image-classification/run_image_classification.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* if else in one line

* within block

* one more place, sorry for mess

* import order

* Update examples/pytorch/image-classification/run_image_classification.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/pytorch/image-classification/run_image_classification_no_trainer.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 576e2823
......@@ -28,6 +28,7 @@ from PIL import Image
from torchvision.transforms import (
CenterCrop,
Compose,
Lambda,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
......@@ -325,7 +326,11 @@ def main():
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
_train_transforms = Compose(
[
RandomResizedCrop(size),
......
......@@ -32,6 +32,7 @@ from torch.utils.data import DataLoader
from torchvision.transforms import (
CenterCrop,
Compose,
Lambda,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
......@@ -331,7 +332,11 @@ def main():
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
train_transforms = Compose(
[
RandomResizedCrop(size),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment