Unverified Commit 56d001b2, authored by Pavel Iakubovskii, committed by GitHub

Fix and simplify semantic-segmentation example (#30145)

* Remove unused augmentation

* Fix pad_if_smaller() and remove unused augmentation

* Add indentation

* Fix requirements

* Update dataset use instructions

* Replace transforms with albumentations

* Replace identity transform with None

* Fix formatting

* Fix comment placement
parent 41579763
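
In short, the hand-rolled torchvision transform classes are replaced by an albumentations pipeline that transforms image and mask together. A condensed, standalone sketch of the pattern now used in both scripts — the size, mean/std, and toy data below are illustrative placeholders; the scripts derive the real values from the pretrained image processor:

```python
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

# Placeholder values; the scripts read these from the image processor.
height, width = 512, 512
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

train_transforms = A.Compose(
    [
        # pad with 255, which the loss ignores, as in the updated scripts
        A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
        A.RandomCrop(height=height, width=width, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(),
    ]
)

# Image and mask go through the same spatial transforms, so crops and flips stay aligned.
image = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
mask = np.random.randint(0, 35, (300, 400), dtype=np.uint8)
out = train_transforms(image=image, mask=mask)
pixel_values, labels = out["image"], out["mask"]  # torch tensors
```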
@@ -25,3 +25,4 @@ torchaudio
 jiwer
 librosa
 evaluate >= 0.2.0
+albumentations
@@ -97,6 +97,10 @@ The script leverages the [🤗 Trainer API](https://huggingface.co/docs/transfor
 Here we show how to fine-tune a [SegFormer](https://huggingface.co/nvidia/mit-b0) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset:
 
+In order to use `segments/sidewalk-semantic`:
+ - Log in to Hugging Face with `huggingface-cli login` (a token can be obtained [here](https://huggingface.co/settings/tokens)).
+ - Accept the terms of use for `sidewalk-semantic` on the [dataset page](https://huggingface.co/datasets/segments/sidewalk-semantic).
+
 ```bash
 python run_semantic_segmentation.py \
     --model_name_or_path nvidia/mit-b0 \
@@ -105,7 +109,6 @@ python run_semantic_segmentation.py \
     --remove_unused_columns False \
     --do_train \
     --do_eval \
-    --evaluation_strategy steps \
     --push_to_hub \
     --push_to_hub_model_id segformer-finetuned-sidewalk-10k-steps \
     --max_steps 10000 \
...
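
The new README steps map to the standard Hub CLI flow; a quick sketch of the one-time setup (the token prompt is interactive):

```bash
# Log in with a token from https://huggingface.co/settings/tokens
huggingface-cli login
# Then accept the dataset's terms of use in the browser:
# https://huggingface.co/datasets/segments/sidewalk-semantic
```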
-git://github.com/huggingface/accelerate.git
 datasets >= 2.0.0
 torch >= 1.3
-evaluate
\ No newline at end of file
+accelerate
+evaluate
+Pillow
+albumentations
\ No newline at end of file
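
These are the pinned requirements that `require_version` points at in both scripts; installing them is one command:

```bash
pip install -r examples/pytorch/semantic-segmentation/requirements.txt
```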
@@ -16,21 +16,20 @@
 import json
 import logging
 import os
-import random
 import sys
 import warnings
 from dataclasses import dataclass, field
+from functools import partial
 from typing import Optional
 
+import albumentations as A
 import evaluate
 import numpy as np
 import torch
+from albumentations.pytorch import ToTensorV2
 from datasets import load_dataset
 from huggingface_hub import hf_hub_download
-from PIL import Image
 from torch import nn
-from torchvision import transforms
-from torchvision.transforms import functional
 
 import transformers
 from transformers import (
@@ -57,118 +56,19 @@ check_min_version("4.40.0.dev0")
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
 
-def pad_if_smaller(img, size, fill=0):
-    size = (size, size) if isinstance(size, int) else size
-    original_width, original_height = img.size
-    pad_height = size[1] - original_height if original_height < size[1] else 0
-    pad_width = size[0] - original_width if original_width < size[0] else 0
-    img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
-    return img
-
-
-class Compose:
-    def __init__(self, transforms):
-        self.transforms = transforms
-
-    def __call__(self, image, target):
-        for t in self.transforms:
-            image, target = t(image, target)
-        return image, target
-
-
-class Identity:
-    def __init__(self):
-        pass
-
-    def __call__(self, image, target):
-        return image, target
-
-
-class Resize:
-    def __init__(self, size):
-        self.size = size
-
-    def __call__(self, image, target):
-        image = functional.resize(image, self.size)
-        target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
-        return image, target
-
-
-class RandomResize:
-    def __init__(self, min_size, max_size=None):
-        self.min_size = min_size
-        if max_size is None:
-            max_size = min_size
-        self.max_size = max_size
-
-    def __call__(self, image, target):
-        size = random.randint(self.min_size, self.max_size)
-        image = functional.resize(image, size)
-        target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
-        return image, target
-
-
-class RandomCrop:
-    def __init__(self, size):
-        self.size = size if isinstance(size, tuple) else (size, size)
-
-    def __call__(self, image, target):
-        image = pad_if_smaller(image, self.size)
-        target = pad_if_smaller(target, self.size, fill=255)
-        crop_params = transforms.RandomCrop.get_params(image, self.size)
-        image = functional.crop(image, *crop_params)
-        target = functional.crop(target, *crop_params)
-        return image, target
-
-
-class RandomHorizontalFlip:
-    def __init__(self, flip_prob):
-        self.flip_prob = flip_prob
-
-    def __call__(self, image, target):
-        if random.random() < self.flip_prob:
-            image = functional.hflip(image)
-            target = functional.hflip(target)
-        return image, target
-
-
-class PILToTensor:
-    def __call__(self, image, target):
-        image = functional.pil_to_tensor(image)
-        target = torch.as_tensor(np.array(target), dtype=torch.int64)
-        return image, target
-
-
-class ConvertImageDtype:
-    def __init__(self, dtype):
-        self.dtype = dtype
-
-    def __call__(self, image, target):
-        image = functional.convert_image_dtype(image, self.dtype)
-        return image, target
-
-
-class Normalize:
-    def __init__(self, mean, std):
-        self.mean = mean
-        self.std = std
-
-    def __call__(self, image, target):
-        image = functional.normalize(image, mean=self.mean, std=self.std)
-        return image, target
-
-
-class ReduceLabels:
-    def __call__(self, image, target):
-        if not isinstance(target, np.ndarray):
-            target = np.array(target).astype(np.uint8)
-        # avoid using underflow conversion
-        target[target == 0] = 255
-        target = target - 1
-        target[target == 254] = 255
-
-        target = Image.fromarray(target)
-        return image, target
+def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
+    """Replace the `0` label with value 255, then reduce all other labels by 1.
+
+    Example:
+        Initial class labels:       0 - background; 1 - road; 2 - car;
+        Transformed class labels: 255 - background; 0 - road; 1 - car;
+
+    **kwargs are required to use this function with albumentations.
+    """
+    labels[labels == 0] = 255
+    labels = labels - 1
+    labels[labels == 254] = 255
+    return labels
 
 
 @dataclass
@@ -365,7 +265,7 @@ def main():
     id2label = {int(k): v for k, v in id2label.items()}
     label2id = {v: str(k) for k, v in id2label.items()}
 
-    # Load the mean IoU metric from the datasets package
+    # Load the mean IoU metric from the evaluate package
     metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)
 
     # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
@@ -424,64 +324,62 @@ def main():
         token=model_args.token,
         trust_remote_code=model_args.trust_remote_code,
     )
+    # `reduce_labels` is a property of dataset labels; in case we use an image_processor
+    # pretrained on another dataset, we should override the default setting
+    image_processor.do_reduce_labels = data_args.reduce_labels
 
-    # Define torchvision transforms to be applied to each image + target.
-    # Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
-    # Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
+    # Define transforms to be applied to each image and target.
     if "shortest_edge" in image_processor.size:
         # We instead set the target size to (shortest_edge, shortest_edge) here to ensure all images are batchable.
-        size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
+        height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
     else:
-        size = (image_processor.size["height"], image_processor.size["width"])
+        height, width = image_processor.size["height"], image_processor.size["width"]
-    train_transforms = Compose(
+    train_transforms = A.Compose(
         [
-            ReduceLabels() if data_args.reduce_labels else Identity(),
-            RandomCrop(size=size),
-            RandomHorizontalFlip(flip_prob=0.5),
-            PILToTensor(),
-            ConvertImageDtype(torch.float),
-            Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
+            A.Lambda(
+                name="reduce_labels",
+                mask=reduce_labels_transform if data_args.reduce_labels else None,
+                p=1.0,
+            ),
+            # pad image with 255, because it is ignored by the loss
+            A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
+            A.RandomCrop(height=height, width=width, p=1.0),
+            A.HorizontalFlip(p=0.5),
+            A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
+            ToTensorV2(),
         ]
     )
-    # Define torchvision transform to be applied to each image.
-    # jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
-    val_transforms = Compose(
+    val_transforms = A.Compose(
         [
-            ReduceLabels() if data_args.reduce_labels else Identity(),
-            Resize(size=size),
-            PILToTensor(),
-            ConvertImageDtype(torch.float),
-            Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
+            A.Lambda(
+                name="reduce_labels",
+                mask=reduce_labels_transform if data_args.reduce_labels else None,
+                p=1.0,
+            ),
+            A.Resize(height=height, width=width, p=1.0),
+            A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
+            ToTensorV2(),
         ]
     )
 
-    def preprocess_train(example_batch):
+    def preprocess_batch(example_batch, transforms: A.Compose):
         pixel_values = []
         labels = []
         for image, target in zip(example_batch["image"], example_batch["label"]):
-            image, target = train_transforms(image.convert("RGB"), target)
-            pixel_values.append(image)
-            labels.append(target)
+            transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
+            pixel_values.append(transformed["image"])
+            labels.append(transformed["mask"])
 
         encoding = {}
-        encoding["pixel_values"] = torch.stack(pixel_values)
-        encoding["labels"] = torch.stack(labels)
+        encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
+        encoding["labels"] = torch.stack(labels).to(torch.long)
 
         return encoding
 
-    def preprocess_val(example_batch):
-        pixel_values = []
-        labels = []
-        for image, target in zip(example_batch["image"], example_batch["label"]):
-            image, target = val_transforms(image.convert("RGB"), target)
-            pixel_values.append(image)
-            labels.append(target)
-
-        encoding = {}
-        encoding["pixel_values"] = torch.stack(pixel_values)
-        encoding["labels"] = torch.stack(labels)
-
-        return encoding
+    # The dataset preprocessing function should take only one argument,
+    # so we use partial to pass in the transforms
+    preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
+    preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
 
     if training_args.do_train:
         if "train" not in dataset:
@@ -491,7 +389,7 @@ def main():
                 dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
             )
         # Set the training transforms
-        dataset["train"].set_transform(preprocess_train)
+        dataset["train"].set_transform(preprocess_train_batch_fn)
 
     if training_args.do_eval:
         if "validation" not in dataset:
@@ -501,7 +399,7 @@ def main():
                 dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
             )
         # Set the validation transforms
-        dataset["validation"].set_transform(preprocess_val)
+        dataset["validation"].set_transform(preprocess_val_batch_fn)
 
     # Initialize our trainer
     trainer = Trainer(
...
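
A quick sanity check of the new `reduce_labels_transform` helper, using the mapping from its docstring (standalone copy of the function for illustration):

```python
import numpy as np

# Copied from the diff above: 0 (background) becomes the 255 ignore index,
# and all remaining labels shift down by one.
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
    labels[labels == 0] = 255
    labels = labels - 1
    labels[labels == 254] = 255
    return labels

mask = np.array([[0, 1, 2], [2, 1, 0]], dtype=np.uint8)
print(reduce_labels_transform(mask))
# [[255   0   1]
#  [  1   0 255]]
```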
@@ -18,9 +18,10 @@ import argparse
 import json
 import math
 import os
-import random
+from functools import partial
 from pathlib import Path
 
+import albumentations as A
 import datasets
 import evaluate
 import numpy as np
@@ -28,12 +29,10 @@ import torch
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from albumentations.pytorch import ToTensorV2
 from datasets import load_dataset
 from huggingface_hub import HfApi, hf_hub_download
-from PIL import Image
 from torch.utils.data import DataLoader
-from torchvision import transforms
-from torchvision.transforms import functional
 from tqdm.auto import tqdm
 
 import transformers
@@ -57,123 +56,23 @@ logger = get_logger(__name__)
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
 
-def pad_if_smaller(img, size, fill=0):
-    min_size = min(img.size)
-    if min_size < size:
-        original_width, original_height = img.size
-        pad_height = size - original_height if original_height < size else 0
-        pad_width = size - original_width if original_width < size else 0
-        img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
-    return img
-
-
-class Compose:
-    def __init__(self, transforms):
-        self.transforms = transforms
-
-    def __call__(self, image, target):
-        for t in self.transforms:
-            image, target = t(image, target)
-        return image, target
-
-
-class Identity:
-    def __init__(self):
-        pass
-
-    def __call__(self, image, target):
-        return image, target
-
-
-class Resize:
-    def __init__(self, size):
-        self.size = size
-
-    def __call__(self, image, target):
-        image = functional.resize(image, self.size)
-        target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
-        return image, target
-
-
-class RandomResize:
-    def __init__(self, min_size, max_size=None):
-        self.min_size = min_size
-        if max_size is None:
-            max_size = min_size
-        self.max_size = max_size
-
-    def __call__(self, image, target):
-        size = random.randint(self.min_size, self.max_size)
-        image = functional.resize(image, size)
-        target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
-        return image, target
-
-
-class RandomCrop:
-    def __init__(self, size):
-        self.size = size
-
-    def __call__(self, image, target):
-        image = pad_if_smaller(image, self.size)
-        target = pad_if_smaller(target, self.size, fill=255)
-        crop_params = transforms.RandomCrop.get_params(image, (self.size, self.size))
-        image = functional.crop(image, *crop_params)
-        target = functional.crop(target, *crop_params)
-        return image, target
-
-
-class RandomHorizontalFlip:
-    def __init__(self, flip_prob):
-        self.flip_prob = flip_prob
-
-    def __call__(self, image, target):
-        if random.random() < self.flip_prob:
-            image = functional.hflip(image)
-            target = functional.hflip(target)
-        return image, target
-
-
-class PILToTensor:
-    def __call__(self, image, target):
-        image = functional.pil_to_tensor(image)
-        target = torch.as_tensor(np.array(target), dtype=torch.int64)
-        return image, target
-
-
-class ConvertImageDtype:
-    def __init__(self, dtype):
-        self.dtype = dtype
-
-    def __call__(self, image, target):
-        image = functional.convert_image_dtype(image, self.dtype)
-        return image, target
-
-
-class Normalize:
-    def __init__(self, mean, std):
-        self.mean = mean
-        self.std = std
-
-    def __call__(self, image, target):
-        image = functional.normalize(image, mean=self.mean, std=self.std)
-        return image, target
-
-
-class ReduceLabels:
-    def __call__(self, image, target):
-        if not isinstance(target, np.ndarray):
-            target = np.array(target).astype(np.uint8)
-        # avoid using underflow conversion
-        target[target == 0] = 255
-        target = target - 1
-        target[target == 254] = 255
-
-        target = Image.fromarray(target)
-        return image, target
+def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
+    """Replace the `0` label with value 255, then reduce all other labels by 1.
+
+    Example:
+        Initial class labels:       0 - background; 1 - road; 2 - car;
+        Transformed class labels: 255 - background; 0 - road; 1 - car;
+
+    **kwargs are required to use this function with albumentations.
+    """
+    labels[labels == 0] = 255
+    labels = labels - 1
+    labels[labels == 254] = 255
+    return labels
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
+    parser = argparse.ArgumentParser(description="Finetune a transformers model on an image semantic segmentation task")
     parser.add_argument(
         "--model_name_or_path",
         type=str,
@@ -418,69 +317,58 @@ def main():
     model = AutoModelForSemanticSegmentation.from_pretrained(
         args.model_name_or_path, config=config, trust_remote_code=args.trust_remote_code
     )
+    # `reduce_labels` is a property of dataset labels; in case we use an image_processor
+    # pretrained on another dataset, we should override the default setting
+    image_processor.do_reduce_labels = args.reduce_labels
 
-    # Preprocessing the datasets
-    # Define torchvision transforms to be applied to each image + target.
-    # Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
-    # Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
+    # Define transforms to be applied to each image and target.
     if "shortest_edge" in image_processor.size:
         # We instead set the target size to (shortest_edge, shortest_edge) here to ensure all images are batchable.
-        size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
+        height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
     else:
-        size = (image_processor.size["height"], image_processor.size["width"])
+        height, width = image_processor.size["height"], image_processor.size["width"]
-    train_transforms = Compose(
+    train_transforms = A.Compose(
         [
-            ReduceLabels() if args.reduce_labels else Identity(),
-            RandomCrop(size=size),
-            RandomHorizontalFlip(flip_prob=0.5),
-            PILToTensor(),
-            ConvertImageDtype(torch.float),
-            Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
+            A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
+            # pad image with 255, because it is ignored by the loss
+            A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
+            A.RandomCrop(height=height, width=width, p=1.0),
+            A.HorizontalFlip(p=0.5),
+            A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
+            ToTensorV2(),
         ]
     )
-    # Define torchvision transform to be applied to each image.
-    # jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
-    val_transforms = Compose(
+    val_transforms = A.Compose(
         [
-            ReduceLabels() if args.reduce_labels else Identity(),
-            Resize(size=size),
-            PILToTensor(),
-            ConvertImageDtype(torch.float),
-            Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
+            A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
+            A.Resize(height=height, width=width, p=1.0),
+            A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
+            ToTensorV2(),
        ]
    )
 
-    def preprocess_train(example_batch):
+    def preprocess_batch(example_batch, transforms: A.Compose):
         pixel_values = []
         labels = []
         for image, target in zip(example_batch["image"], example_batch["label"]):
-            image, target = train_transforms(image.convert("RGB"), target)
-            pixel_values.append(image)
-            labels.append(target)
+            transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
+            pixel_values.append(transformed["image"])
+            labels.append(transformed["mask"])
 
         encoding = {}
-        encoding["pixel_values"] = torch.stack(pixel_values)
-        encoding["labels"] = torch.stack(labels)
+        encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
+        encoding["labels"] = torch.stack(labels).to(torch.long)
 
         return encoding
 
-    def preprocess_val(example_batch):
-        pixel_values = []
-        labels = []
-        for image, target in zip(example_batch["image"], example_batch["label"]):
-            image, target = val_transforms(image.convert("RGB"), target)
-            pixel_values.append(image)
-            labels.append(target)
-
-        encoding = {}
-        encoding["pixel_values"] = torch.stack(pixel_values)
-        encoding["labels"] = torch.stack(labels)
-
-        return encoding
+    # The dataset preprocessing function should take only one input argument,
+    # so we use partial to pass in the transforms
+    preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
+    preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
 
     with accelerator.main_process_first():
-        train_dataset = dataset["train"].with_transform(preprocess_train)
-        eval_dataset = dataset["validation"].with_transform(preprocess_val)
+        train_dataset = dataset["train"].with_transform(preprocess_train_batch_fn)
+        eval_dataset = dataset["validation"].with_transform(preprocess_val_batch_fn)
 
     train_dataloader = DataLoader(
         train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
@@ -726,7 +614,7 @@ def main():
             f"eval_{k}": v.tolist() if isinstance(v, np.ndarray) else v for k, v in eval_metrics.items()
         }
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump(all_results, f)
+            json.dump(all_results, f, indent=2)
 
 if __name__ == "__main__":
...
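
The "Replace identity transform with None" bullet relies on `A.Lambda` treating an unset target function as a no-op; a small standalone check of that behavior (assuming albumentations' documented `Lambda` defaults):

```python
import albumentations as A
import numpy as np

# With mask=None, A.Lambda falls back to a no-op for the mask target,
# so the same pipeline works whether or not label reduction is enabled.
identity = A.Compose([A.Lambda(name="reduce_labels", mask=None, p=1.0)])

image = np.zeros((4, 4, 3), dtype=np.uint8)
mask = np.array([[0, 1, 2, 2]] * 4, dtype=np.uint8)
out = identity(image=image, mask=mask)
assert (out["mask"] == mask).all()  # mask passes through unchanged
```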