"docs/source/vscode:/vscode.git/clone" did not exist on "26dd041c6e45379141302e2d293ab4cd9cf805d4"
Unverified Commit e2d6d5ce authored by Michal Jamroz's avatar Michal Jamroz Committed by GitHub
Browse files

Normalize only if needed (#26049)



* Normalize only if needed

* Update examples/pytorch/image-classification/run_image_classification.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* if else in one line

* within block

* one more place, sorry for mess

* import order

* Update examples/pytorch/image-classification/run_image_classification.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/pytorch/image-classification/run_image_classification_no_trainer.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 576e2823
...@@ -28,6 +28,7 @@ from PIL import Image ...@@ -28,6 +28,7 @@ from PIL import Image
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop, CenterCrop,
Compose, Compose,
Lambda,
Normalize, Normalize,
RandomHorizontalFlip, RandomHorizontalFlip,
RandomResizedCrop, RandomResizedCrop,
...@@ -325,7 +326,11 @@ def main(): ...@@ -325,7 +326,11 @@ def main():
size = image_processor.size["shortest_edge"] size = image_processor.size["shortest_edge"]
else: else:
size = (image_processor.size["height"], image_processor.size["width"]) size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
_train_transforms = Compose( _train_transforms = Compose(
[ [
RandomResizedCrop(size), RandomResizedCrop(size),
......
...@@ -32,6 +32,7 @@ from torch.utils.data import DataLoader ...@@ -32,6 +32,7 @@ from torch.utils.data import DataLoader
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop, CenterCrop,
Compose, Compose,
Lambda,
Normalize, Normalize,
RandomHorizontalFlip, RandomHorizontalFlip,
RandomResizedCrop, RandomResizedCrop,
...@@ -331,7 +332,11 @@ def main(): ...@@ -331,7 +332,11 @@ def main():
size = image_processor.size["shortest_edge"] size = image_processor.size["shortest_edge"]
else: else:
size = (image_processor.size["height"], image_processor.size["width"]) size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
train_transforms = Compose( train_transforms = Compose(
[ [
RandomResizedCrop(size), RandomResizedCrop(size),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment