"vscode:/vscode.git/clone" did not exist on "34c199403be222da54f6eb55cb19ece97b9ee995"
Unverified Commit cc0a415e authored by Nathan Raw's avatar Nathan Raw Committed by GitHub
Browse files

update image classification example (#13824)

*  update image classification example

* 📌 update reqs
parent 6c088406
torch>=1.9.0 torch>=1.5.0
torchvision>=0.10.0 torchvision>=0.6.0
\ No newline at end of file datasets>=1.8.0
\ No newline at end of file
...@@ -56,7 +56,7 @@ logger = logging.getLogger(__name__) ...@@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.12.0.dev0") check_min_version("4.12.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
...@@ -102,7 +102,6 @@ class DataTrainingArguments: ...@@ -102,7 +102,6 @@ class DataTrainingArguments:
"value if set." "value if set."
}, },
) )
image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."})
def __post_init__(self): def __post_init__(self):
data_files = dict() data_files = dict()
...@@ -210,35 +209,6 @@ def main(): ...@@ -210,35 +209,6 @@ def main():
task="image-classification", task="image-classification",
) )
# Define torchvision transforms to be applied to each image.
normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
_train_transforms = Compose(
[
RandomResizedCrop(data_args.image_size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
_val_transforms = Compose(
[
Resize(data_args.image_size),
CenterCrop(data_args.image_size),
ToTensor(),
normalize,
]
)
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
return example_batch
def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
return example_batch
# If we don't have a validation split, split off a percentage of train as validation. # If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
...@@ -281,20 +251,42 @@ def main(): ...@@ -281,20 +251,42 @@ def main():
revision=model_args.model_revision, revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None, use_auth_token=True if model_args.use_auth_token else None,
) )
# NOTE - We aren't directly using this feature extractor since we defined custom transforms above.
# We initialize this instance below and pass it to Trainer to ensure that the feature extraction
# config, preprocessor_config.json, is included in output directories.
# This way if we push a model to the hub, the inference widget will work.
feature_extractor = AutoFeatureExtractor.from_pretrained( feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path, model_args.feature_extractor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
revision=model_args.model_revision, revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None, use_auth_token=True if model_args.use_auth_token else None,
size=data_args.image_size,
image_mean=normalize.mean,
image_std=normalize.std,
) )
# Define torchvision transforms to be applied to each image.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_train_transforms = Compose(
[
RandomResizedCrop(feature_extractor.size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
_val_transforms = Compose(
[
Resize(feature_extractor.size),
CenterCrop(feature_extractor.size),
ToTensor(),
normalize,
]
)
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
return example_batch
def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
return example_batch
if training_args.do_train: if training_args.do_train:
if "train" not in ds: if "train" not in ds:
raise ValueError("--do_train requires a train dataset") raise ValueError("--do_train requires a train dataset")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment