Unverified Commit 56d001b2 authored by Pavel Iakubovskii, committed by GitHub

Fix and simplify semantic-segmentation example (#30145)

* Remove unused augmentation

* Fix pad_if_smaller() and remove unused augmentation

* Add indentation

* Fix requirements

* Update dataset use instructions

* Replace transforms with albumentations

* Replace identity transform with None

* Fixing formatting

* Fixed comment place
parent 41579763
@@ -25,3 +25,4 @@ torchaudio
jiwer
librosa
evaluate >= 0.2.0
albumentations
@@ -97,6 +97,10 @@ The script leverages the [🤗 Trainer API](https://huggingface.co/docs/transfor
Here we show how to fine-tune a [SegFormer](https://huggingface.co/nvidia/mit-b0) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset:
In order to use `segments/sidewalk-semantic`:
- Log in to Hugging Face with `huggingface-cli login` (an access token can be created [here](https://huggingface.co/settings/tokens)); see the example below.
- Accept the terms of use for `sidewalk-semantic` on the [dataset page](https://huggingface.co/datasets/segments/sidewalk-semantic).
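For reference, the login step looks roughly like this (a minimal sketch; accepting the terms of use is done in the browser on the dataset page):

```bash
# log in interactively and paste an access token when prompted
huggingface-cli login
```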
```bash
python run_semantic_segmentation.py \
--model_name_or_path nvidia/mit-b0 \
@@ -105,7 +109,6 @@ python run_semantic_segmentation.py \
--remove_unused_columns False \
--do_train \
--do_eval \
--evaluation_strategy steps \
--push_to_hub \
--push_to_hub_model_id segformer-finetuned-sidewalk-10k-steps \
--max_steps 10000 \
......
git://github.com/huggingface/accelerate.git
datasets >= 2.0.0
torch >= 1.3
accelerate
evaluate
Pillow
albumentations
\ No newline at end of file
@@ -16,21 +16,20 @@
import json
import logging
import os
import random
import sys
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import Optional
import albumentations as A
import evaluate
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.transforms import functional
import transformers
from transformers import (
@@ -57,118 +56,19 @@ check_min_version("4.40.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
def pad_if_smaller(img, size, fill=0):
size = (size, size) if isinstance(size, int) else size
original_width, original_height = img.size
pad_height = size[1] - original_height if original_height < size[1] else 0
pad_width = size[0] - original_width if original_width < size[0] else 0
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
return img
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
class Identity:
def __init__(self):
pass
def __call__(self, image, target):
return image, target
class Resize:
def __init__(self, size):
self.size = size
def __call__(self, image, target):
image = functional.resize(image, self.size)
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
return image, target
class RandomResize:
def __init__(self, min_size, max_size=None):
self.min_size = min_size
if max_size is None:
max_size = min_size
self.max_size = max_size
def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = functional.resize(image, size)
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
return image, target
class RandomCrop:
def __init__(self, size):
self.size = size if isinstance(size, tuple) else (size, size)
def __call__(self, image, target):
image = pad_if_smaller(image, self.size)
target = pad_if_smaller(target, self.size, fill=255)
crop_params = transforms.RandomCrop.get_params(image, self.size)
image = functional.crop(image, *crop_params)
target = functional.crop(target, *crop_params)
return image, target
class RandomHorizontalFlip:
def __init__(self, flip_prob):
self.flip_prob = flip_prob
def __call__(self, image, target):
if random.random() < self.flip_prob:
image = functional.hflip(image)
target = functional.hflip(target)
return image, target
class PILToTensor:
def __call__(self, image, target):
image = functional.pil_to_tensor(image)
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target
class ConvertImageDtype:
def __init__(self, dtype):
self.dtype = dtype
def __call__(self, image, target):
image = functional.convert_image_dtype(image, self.dtype)
return image, target
class Normalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target):
image = functional.normalize(image, mean=self.mean, std=self.std)
return image, target
class ReduceLabels:
def __call__(self, image, target):
if not isinstance(target, np.ndarray):
target = np.array(target).astype(np.uint8)
# avoid using underflow conversion
target[target == 0] = 255
target = target - 1
target[target == 254] = 255
target = Image.fromarray(target)
return image, target
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
    """Set the `0` label to 255 and reduce all other labels by 1.

    Example:
        Initial class labels:     0 - background; 1 - road; 2 - car;
        Transformed class labels: 255 - background; 0 - road; 1 - car;

    **kwargs are required to use this function with albumentations.
    """
    labels[labels == 0] = 255
    labels = labels - 1
    labels[labels == 254] = 255
    return labels
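As a quick illustration (not part of the commit), here is how `reduce_labels_transform` behaves on a hypothetical toy mask and how it plugs into an albumentations pipeline via `A.Lambda`, assuming the function defined above is in scope:

```python
import albumentations as A
import numpy as np

# toy mask: 0 - background, 1 - road, 2 - car
mask = np.array([[0, 1], [2, 0]], dtype=np.uint8)
print(reduce_labels_transform(mask.copy()))  # background -> 255, road -> 0, car -> 1

# applied to the mask only, as in the training/validation pipelines below
reduce_labels = A.Compose([A.Lambda(name="reduce_labels", mask=reduce_labels_transform, p=1.0)])
out = reduce_labels(image=np.zeros((2, 2, 3), dtype=np.uint8), mask=mask)
print(out["mask"])  # [[255   0] [  1 255]]
```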
@dataclass
@@ -365,7 +265,7 @@ def main():
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: str(k) for k, v in id2label.items()}
# Load the mean IoU metric from the datasets package
# Load the mean IoU metric from the evaluate package
metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
@@ -424,64 +324,62 @@ def main():
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
# `reduce_labels` is a property of the dataset labels; in case we use an image_processor
# pretrained on another dataset, we should override its default setting
image_processor.do_reduce_labels = data_args.reduce_labels
# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
# Define transforms to be applied to each image and target.
if "shortest_edge" in image_processor.size:
# We instead set the target size to (shortest_edge, shortest_edge) here to ensure all images are batchable.
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
train_transforms = Compose(
height, width = image_processor.size["height"], image_processor.size["width"]
train_transforms = A.Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
RandomCrop(size=size),
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
A.Lambda(
name="reduce_labels",
mask=reduce_labels_transform if data_args.reduce_labels else None,
p=1.0,
),
# pad image with 255, because it is ignored by loss
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
A.RandomCrop(height=height, width=width, p=1.0),
A.HorizontalFlip(p=0.5),
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
ToTensorV2(),
]
)
# Define torchvision transform to be applied to each image.
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
val_transforms = Compose(
val_transforms = A.Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
A.Lambda(
name="reduce_labels",
mask=reduce_labels_transform if data_args.reduce_labels else None,
p=1.0,
),
A.Resize(height=height, width=width, p=1.0),
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
ToTensorV2(),
]
)
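For context (not part of the diff), here is a self-contained sketch of what an albumentations pipeline like the ones above produces for a single image/mask pair; the 512×512 target size, the ImageNet mean/std, the dummy arrays, and the omission of the reduce-labels step are assumptions for the example, not values taken from the script:

```python
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

height, width = 512, 512  # assumed; the script derives this from image_processor.size
demo_transforms = A.Compose(
    [
        # pad with 255 so padded mask pixels are ignored by the loss
        A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
        A.RandomCrop(height=height, width=width, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
        ToTensorV2(),
    ]
)

image = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)  # dummy RGB image (H, W, C)
mask = np.random.randint(0, 3, size=(300, 400), dtype=np.uint8)        # dummy segmentation mask (H, W)

out = demo_transforms(image=image, mask=mask)
print(out["image"].shape, out["image"].dtype)  # torch.Size([3, 512, 512]) torch.float32
print(out["mask"].shape, out["mask"].dtype)    # torch.Size([512, 512]) torch.uint8
```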
def preprocess_train(example_batch):
def preprocess_batch(example_batch, transforms: A.Compose):
pixel_values = []
labels = []
for image, target in zip(example_batch["image"], example_batch["label"]):
image, target = train_transforms(image.convert("RGB"), target)
pixel_values.append(image)
labels.append(target)
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
pixel_values.append(transformed["image"])
labels.append(transformed["mask"])
encoding = {}
encoding["pixel_values"] = torch.stack(pixel_values)
encoding["labels"] = torch.stack(labels)
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
encoding["labels"] = torch.stack(labels).to(torch.long)
return encoding
def preprocess_val(example_batch):
pixel_values = []
labels = []
for image, target in zip(example_batch["image"], example_batch["label"]):
image, target = val_transforms(image.convert("RGB"), target)
pixel_values.append(image)
labels.append(target)
encoding = {}
encoding["pixel_values"] = torch.stack(pixel_values)
encoding["labels"] = torch.stack(labels)
return encoding
# Preprocess function for dataset should have only one argument,
# so we use partial to pass the transforms
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
if training_args.do_train:
if "train" not in dataset:
@@ -491,7 +389,7 @@ def main():
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
# Set the training transforms
dataset["train"].set_transform(preprocess_train)
dataset["train"].set_transform(preprocess_train_batch_fn)
if training_args.do_eval:
if "validation" not in dataset:
@@ -501,7 +399,7 @@ def main():
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# Set the validation transforms
dataset["validation"].set_transform(preprocess_val)
dataset["validation"].set_transform(preprocess_val_batch_fn)
# Initialize our trainer
trainer = Trainer(
......
@@ -18,9 +18,10 @@ import argparse
import json
import math
import os
import random
from functools import partial
from pathlib import Path
import albumentations as A
import datasets
import evaluate
import numpy as np
@@ -28,12 +29,10 @@ import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from albumentations.pytorch import ToTensorV2
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import functional
from tqdm.auto import tqdm
import transformers
@@ -57,123 +56,23 @@ logger = get_logger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
def pad_if_smaller(img, size, fill=0):
min_size = min(img.size)
if min_size < size:
original_width, original_height = img.size
pad_height = size - original_height if original_height < size else 0
pad_width = size - original_width if original_width < size else 0
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
return img
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
class Identity:
def __init__(self):
pass
def __call__(self, image, target):
return image, target
class Resize:
def __init__(self, size):
self.size = size
def __call__(self, image, target):
image = functional.resize(image, self.size)
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
return image, target
class RandomResize:
def __init__(self, min_size, max_size=None):
self.min_size = min_size
if max_size is None:
max_size = min_size
self.max_size = max_size
def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = functional.resize(image, size)
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
return image, target
class RandomCrop:
def __init__(self, size):
self.size = size
def __call__(self, image, target):
image = pad_if_smaller(image, self.size)
target = pad_if_smaller(target, self.size, fill=255)
crop_params = transforms.RandomCrop.get_params(image, (self.size, self.size))
image = functional.crop(image, *crop_params)
target = functional.crop(target, *crop_params)
return image, target
class RandomHorizontalFlip:
def __init__(self, flip_prob):
self.flip_prob = flip_prob
def __call__(self, image, target):
if random.random() < self.flip_prob:
image = functional.hflip(image)
target = functional.hflip(target)
return image, target
class PILToTensor:
def __call__(self, image, target):
image = functional.pil_to_tensor(image)
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target
class ConvertImageDtype:
def __init__(self, dtype):
self.dtype = dtype
def __call__(self, image, target):
image = functional.convert_image_dtype(image, self.dtype)
return image, target
class Normalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target):
image = functional.normalize(image, mean=self.mean, std=self.std)
return image, target
class ReduceLabels:
def __call__(self, image, target):
if not isinstance(target, np.ndarray):
target = np.array(target).astype(np.uint8)
# avoid using underflow conversion
target[target == 0] = 255
target = target - 1
target[target == 254] = 255
target = Image.fromarray(target)
return image, target
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
    """Set the `0` label to 255 and reduce all other labels by 1.

    Example:
        Initial class labels:     0 - background; 1 - road; 2 - car;
        Transformed class labels: 255 - background; 0 - road; 1 - car;

    **kwargs are required to use this function with albumentations.
    """
    labels[labels == 0] = 255
    labels = labels - 1
    labels[labels == 254] = 255
    return labels
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
parser = argparse.ArgumentParser(description="Finetune a transformers model on an image semantic segmentation task")
parser.add_argument(
"--model_name_or_path",
type=str,
@@ -418,69 +317,58 @@ def main():
model = AutoModelForSemanticSegmentation.from_pretrained(
args.model_name_or_path, config=config, trust_remote_code=args.trust_remote_code
)
# `reduce_labels` is a property of the dataset labels; in case we use an image_processor
# pretrained on another dataset, we should override its default setting
image_processor.do_reduce_labels = args.reduce_labels
# Preprocessing the datasets
# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
# Define transforms to be applied to each image and target.
if "shortest_edge" in image_processor.size:
# We instead set the target size to (shortest_edge, shortest_edge) here to ensure all images are batchable.
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
train_transforms = Compose(
height, width = image_processor.size["height"], image_processor.size["width"]
train_transforms = A.Compose(
[
ReduceLabels() if args.reduce_labels else Identity(),
RandomCrop(size=size),
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
# pad image with 255, because it is ignored by loss
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
A.RandomCrop(height=height, width=width, p=1.0),
A.HorizontalFlip(p=0.5),
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
ToTensorV2(),
]
)
# Define torchvision transform to be applied to each image.
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
val_transforms = Compose(
val_transforms = A.Compose(
[
ReduceLabels() if args.reduce_labels else Identity(),
Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
A.Lambda(name="reduce_labels", mask=reduce_labels_transform if args.reduce_labels else None, p=1.0),
A.Resize(height=height, width=width, p=1.0),
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
ToTensorV2(),
]
)
def preprocess_train(example_batch):
def preprocess_batch(example_batch, transforms: A.Compose):
pixel_values = []
labels = []
for image, target in zip(example_batch["image"], example_batch["label"]):
image, target = train_transforms(image.convert("RGB"), target)
pixel_values.append(image)
labels.append(target)
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
pixel_values.append(transformed["image"])
labels.append(transformed["mask"])
encoding = {}
encoding["pixel_values"] = torch.stack(pixel_values)
encoding["labels"] = torch.stack(labels)
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
encoding["labels"] = torch.stack(labels).to(torch.long)
return encoding
def preprocess_val(example_batch):
pixel_values = []
labels = []
for image, target in zip(example_batch["image"], example_batch["label"]):
image, target = val_transforms(image.convert("RGB"), target)
pixel_values.append(image)
labels.append(target)
encoding = {}
encoding["pixel_values"] = torch.stack(pixel_values)
encoding["labels"] = torch.stack(labels)
return encoding
# Preprocess function for dataset should have only one input argument,
# so we use partial to pass transforms
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)
with accelerator.main_process_first():
train_dataset = dataset["train"].with_transform(preprocess_train)
eval_dataset = dataset["validation"].with_transform(preprocess_val)
train_dataset = dataset["train"].with_transform(preprocess_train_batch_fn)
eval_dataset = dataset["validation"].with_transform(preprocess_val_batch_fn)
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
@@ -726,7 +614,7 @@ def main():
f"eval_{k}": v.tolist() if isinstance(v, np.ndarray) else v for k, v in eval_metrics.items()
}
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump(all_results, f)
json.dump(all_results, f, indent=2)
if __name__ == "__main__":
......