Commit 2eea6216 authored by rprenger

Merging with main and fixing merge conflict

parents ed6806ac 5f694372
......@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
GENERATE_NUM = 0
BEAM_NUM = 1
lock = threading.Lock()
class MegatronGenerate(Resource):
......@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, 0)
@staticmethod
def send_do_beam_search():
choice = torch.cuda.LongTensor([BEAM_NUM])
torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
......@@ -148,15 +155,57 @@ class MegatronGenerate(Resource):
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
beam_width = None
if "beam_width" in request.get_json():
beam_width = request.get_json()["beam_width"]
if not isinstance(beam_width, int):
return "beam_width must be integer"
if beam_width < 1:
return "beam_width must be an integer > 1"
if len(prompts) > 1:
return "When doing beam_search, batch size must be 1"
stop_token = 50256  # default: GPT-2 <|endoftext|> token id
if "stop_token" in request.get_json():
stop_token = request.get_json()["stop_token"]
if not isinstance(stop_token, int):
return "stop_token must be an integer"
length_penalty = 1.0
if "length_penalty" in request.get_json():
length_penalty = request.get_json()["length_penalty"]
if not isinstance(length_penalty, float):
return "length_penalty must be a float"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
try:
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
if beam_width is not None:
MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search
response, response_seg, response_scores = \
beam_search_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size=beam_width,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=beam_width, # Returning whole beam
length_penalty=length_penalty
)
return jsonify({"text": response,
"segments": response_seg,
"scores": response_scores})
else:
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
......@@ -171,13 +220,15 @@ class MegatronGenerate(Resource):
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
class MegatronServer(object):
def __init__(self, model):
......
......@@ -42,6 +42,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
......@@ -99,6 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
# Initialize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
......@@ -361,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
args = get_args()
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
......@@ -409,78 +411,44 @@ def train_step(forward_step_func, data_iterator,
partition.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# Empty unused memory
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Reduce gradients.
timers('backward-reduce-model-grads').start()
optimizer.reduce_model_grads(args, timers)
timers('backward-reduce-model-grads').stop()
# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop()
# Gather params.
if update_successful:
timers('backward-gather-model-params').start()
optimizer.gather_model_params(args, timers)
timers('backward-gather-model-params').stop()
# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
......@@ -491,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
# Empty unused memory
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
......@@ -558,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-layernorm-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-reduce-model-grads')
add_to_logging('backward-gather-model-params')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-count-zeros')
add_to_logging('optimizer-inner-step')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
......
......@@ -24,7 +24,6 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
......@@ -204,3 +203,22 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
......@@ -229,7 +229,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step,
valid_dataloader, model,
iteration, False)
iteration, None, False)
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
......
......@@ -15,12 +15,15 @@
"""Vision-classification finetuning/evaluation."""
from megatron import get_args
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0
from megatron.model.vit_model import VitModel
from megatron.model.vision.classification import VitClassificationModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.eval_utils import accuracy_func_provider
from tasks.vision.classification.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
from megatron.utils import average_losses_across_data_parallel_group
def classification():
......@@ -30,7 +33,7 @@ def classification():
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
crop_size=args.img_dim,
image_size=(args.img_h, args.img_w),
)
return train_ds, valid_ds
......@@ -40,16 +43,52 @@ def classification():
print_rank_0("building classification model for ImageNet ...")
return VitModel(num_classes=args.num_classes, finetune=True,
pre_process=pre_process, post_process=post_process)
return VitClassificationModel(num_classes=args.num_classes, finetune=True,
pre_process=pre_process, post_process=post_process)
def process_batch(batch):
"""Process batch and produce inputs for the model."""
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
return images, labels
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss = F.cross_entropy(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers("batch generator").start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
output_tensor = model(images)
return output_tensor, partial(cross_entropy_loss_func, labels)
"""Finetune/evaluate."""
finetune(
train_valid_datasets_provider,
model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
classification()
......@@ -33,11 +33,10 @@ def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
data_path = args.data_path
crop_size = args.img_dim
crop_size = (args.img_h, args.img_w)
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# Build dataloaders.
val_data_path = os.path.join(data_path[0], "val")
val_data_path = data_path[1]
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_val = transforms.Compose(
[
......@@ -54,6 +53,7 @@ def accuracy_func_provider():
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
shuffle=False
)
def metrics_func(model, epoch):
......@@ -71,7 +71,6 @@ def accuracy_func_provider():
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
args = get_args()
forward_backward_func = get_forward_backward_func()
for m in model:
m.eval()
......@@ -98,7 +97,6 @@ def calculate_correct_answers(model, dataloader, epoch):
images, labels = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
......
......@@ -17,11 +17,10 @@
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron import mpu, utils
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
......@@ -29,7 +28,10 @@ from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module, ModelType
def process_batch(batch):
......@@ -39,45 +41,16 @@ def process_batch(batch):
return images, labels
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss = F.cross_entropy(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers("batch generator").start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
output_tensor = model(images)
return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
def build_data_loader(dataset, micro_batch_size,
num_workers, drop_last, shuffle):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank
dataset, num_replicas=world_size, rank=rank,
drop_last=drop_last, shuffle=shuffle
)
# Data loader. Note that batch size is the per GPU batch size.
......@@ -112,14 +85,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
args.num_workers, False, True)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
args.num_workers, True, False)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments
......@@ -132,6 +105,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
return train_dataloader, valid_dataloader
def _train(
model,
optimizer,
......@@ -140,6 +114,7 @@ def _train(
train_dataloader,
valid_dataloader,
end_of_epoch_callback,
process_non_loss_data_func=None
):
"""Train the model."""
args = get_args()
......@@ -167,6 +142,7 @@ def _train(
# Set the data loader epoch to shuffle the index iterator.
train_dataloader.sampler.set_epoch(args.seed + epoch)
train_dataloader.dataset.set_epoch(epoch)
# For all the batches in the dataset.
for iteration_, batch in enumerate(train_dataloader):
......@@ -185,8 +161,6 @@ def _train(
# Logging.
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(
losses_dict,
......@@ -202,20 +176,16 @@ def _train(
)
# Autoresume
if args.adlr_autoresume and (
iteration % args.adlr_autoresume_interval == 0
):
check_adlr_autoresume_termination(
iteration, model, optimizer, opt_param_scheduler
)
if args.adlr_autoresume and \
iteration % args.adlr_autoresume_interval == 0:
check_adlr_autoresume_termination(iteration, model, optimizer,
opt_param_scheduler)
# Checkpointing
if (
args.save
and args.save_interval
and iteration % args.save_interval == 0
):
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer,
opt_param_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
......@@ -226,13 +196,10 @@ def _train(
valid_dataloader,
model,
iteration,
process_non_loss_data_func,
False,
)
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
end_of_epoch_callback(model, epoch)
......@@ -241,7 +208,9 @@ def _train(
def finetune(
train_valid_datasets_provider,
model_provider,
forward_step=_cross_entropy_forward_step,
forward_step,
model_type=ModelType.encoder_or_decoder,
process_non_loss_data_func=None,
end_of_epoch_callback_provider=None,
):
"""Main finetune function used across all tasks."""
......@@ -266,7 +235,12 @@ def finetune(
# Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
model, optimizer, opt_param_scheduler = \
setup_model_and_optimizer(
model_provider,
model_type,
scale_lr_cond=lambda name, param: ".head." in name,
lr_mult=args.head_lr_mult)
timers("model and optimizer").stop()
# If pretrained checkpoint is provided and we have not trained for
......@@ -274,13 +248,34 @@ def finetune(
# checkpoint.
timers("pretrained checkpoint").start()
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
_ = load_checkpoint(model, None, None, strict=False)
args.load = original_load
if args.pretrained_checkpoint_type == 'default':
original_load = args.load
args.load = args.pretrained_checkpoint
_ = load_checkpoint(model, None, None, strict=False)
args.load = original_load
elif args.pretrained_checkpoint_type == 'external':
unwrap_model = utils.unwrap_model(model)
state_dict = torch.load(args.pretrained_checkpoint,
map_location="cpu")
unwrap_model[0].module.backbone.load_state_dict(state_dict,
strict=False)
elif args.pretrained_checkpoint_type == 'contrastive':
unwrap_model = utils.unwrap_model(model)
state_dict = torch.load(args.pretrained_checkpoint,
map_location="cpu")
state_dict = state_dict["model"]
state_dict = {k.replace("teacher.backbone.", ""): v
for k, v in state_dict.items()
if k.startswith("teacher.backbone.")}
unwrap_model[0].module.backbone.load_state_dict(state_dict,
strict=False)
else:
raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type))
# This is critical when only model is loaded. We should make sure
# master parameters are also updated.
optimizer.reload_model_params()
timers("pretrained checkpoint").stop()
# Print setup timing.
......@@ -305,11 +300,13 @@ def finetune(
train_dataloader,
valid_dataloader,
end_of_epoch_callback,
process_non_loss_data_func,
)
# Or just evaluate.
else:
if end_of_epoch_callback is not None:
print_rank_0("evaluation only mode, setting epoch to -1")
end_of_epoch_callback(model, epoch=-1, output_predictions=True)
end_of_epoch_callback(model, epoch=-1)
print_rank_0("done :-)")
......@@ -28,32 +28,24 @@ sys.path.append(
)
from megatron import get_args
from megatron.initialize import initialize_megatron
from classification import main
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title="tasks")
group.add_argument(
"--epochs",
type=int,
default=None,
help="Number of finetunning epochs. Zero results in "
"evaluation only.",
)
group.add_argument(
"--pretrained-checkpoint",
type=str,
default=None,
help="Pretrained checkpoint used for finetunning.",
)
group.add_argument(
"--keep-last",
action="store_true",
help="Keep the last batch (maybe incomplete) in" "the data loader",
)
group.add_argument('--task', type=str, default='segment',
choices=['classify', 'segment_setr', 'segment_segformer'],
help='task name.')
group.add_argument("--epochs", type=int, default=None,
help="Number of finetunning epochs. Zero results in "
"evaluation only.")
group.add_argument('--pretrained-checkpoint-type', type=str, default='default',
choices=['default', 'external', 'contrastive'],
help='Type of pretrained checkpoint')
group.add_argument("--pretrained-checkpoint", type=str, default=None,
help="Pretrained checkpoint used for finetunning.")
group.add_argument('--seg-stride', type=int, default=None,
help='sliding window stride during evaluation')
return parser
......@@ -61,4 +53,14 @@ if __name__ == "__main__":
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
main()
if args.task == 'classify':
from tasks.vision.classification.classification import main
main()
elif args.task == 'segment_setr':
from tasks.vision.segmentation.finetune_setr import main
main()
elif args.task == 'segment_segformer':
from tasks.vision.segmentation.finetune_segformer import main
main()
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py
# modified to change the max label index from 255 to 19 (num_classes)
import torch
import json
import os
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
from torchvision.datasets import VisionDataset
from PIL import Image
from megatron import print_rank_0
class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
otherwise ``train``, ``train_extra`` or ``val``
mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
or ``color``. Can also be a list to output a tuple with all specified target types.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version.
Examples:
Get semantic segmentation target
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
target_type='semantic')
img, smnt = dataset[0]
Get multiple targets
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0]
Validate on the "coarse" set
.. code-block:: python
dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
target_type='semantic')
img, smnt = dataset[0]
"""
num_classes = 19
ignore_index = 19
color_table = torch.tensor(
[[128, 64, 128],
[244, 35, 232],
[70, 70, 70],
[102, 102, 156],
[190, 153, 153],
[153, 153, 153],
[250, 170, 30],
[220, 220, 0],
[107, 142, 35],
[152, 251, 152],
[70, 130, 180],
[220, 20, 60],
[255, 0, 0],
[0, 0, 142],
[0, 0, 70],
[0, 60, 100],
[0, 80, 100],
[0, 0, 230],
[119, 11, 32],
[0, 0, 0]], dtype=torch.float, device='cuda')
# Based on https://github.com/mcordts/cityscapesScripts
CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id',
'category', 'category_id', 'has_instances', 'ignore_in_eval', 'color'])
classes = [
CityscapesClass('unlabeled', 0, 19, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('ego vehicle', 1, 19, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('out of roi', 3, 19, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('static', 4, 19, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('dynamic', 5, 19, 'void', 0, False, True, (111, 74, 0)),
CityscapesClass('ground', 6, 19, 'void', 0, False, True, (81, 0, 81)),
CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
CityscapesClass('parking', 9, 19, 'flat', 1, False, True, (250, 170, 160)),
CityscapesClass('rail track', 10, 19, 'flat', 1, False, True, (230, 150, 140)),
CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
CityscapesClass('guard rail', 14, 19, 'construction', 2, False, True, (180, 165, 180)),
CityscapesClass('bridge', 15, 19, 'construction', 2, False, True, (150, 100, 100)),
CityscapesClass('tunnel', 16, 19, 'construction', 2, False, True, (150, 120, 90)),
CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
CityscapesClass('polegroup', 18, 19, 'object', 3, False, True, (153, 153, 153)),
CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
CityscapesClass('caravan', 29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
CityscapesClass('trailer', 30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
]
# label2trainid
label2trainid = { label.id : label.train_id for label in classes}
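# For example, label2trainid[7] == 0 ('road') while label2trainid[0] == 19
# ('unlabeled', i.e. the ignore index), per the class table above.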
def __init__(
self,
root: str,
split: str = "train",
mode: str = "fine",
resolution: int = 1024,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
self.images_dir = os.path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit', split)
self.targets_dir = os.path.join(self.root, 'gtFine_trainvaltest/gtFine', split)
self.split = split
self.resolution = resolution
self.images = []
self.targets = []
for city in sorted(os.listdir(self.images_dir)):
img_dir = os.path.join(self.images_dir, city)
target_dir = os.path.join(self.targets_dir, city)
for file_name in os.listdir(img_dir):
target_name = '{}_{}_labelIds.png'.format(file_name.split('_leftImg8bit')[0], self.mode)
self.images.append(os.path.join(img_dir, file_name))
self.targets.append(os.path.join(target_dir, target_name))
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
"""
image = Image.open(self.images[index]).convert('RGB')
target = Image.open(self.targets[index])
target = np.array(target)
target_copy = target.copy()
for k, v in Cityscapes.label2trainid.items():
binary_target = (target == k)
target_copy[binary_target] = v
target = target_copy
target = Image.fromarray(target.astype(np.uint8))
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
def __len__(self) -> int:
return len(self.images)
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron.data.autoaugment import ImageNetPolicy
from tasks.vision.segmentation.cityscapes import Cityscapes
import tasks.vision.segmentation.transforms as ET
from megatron.data.autoaugment import ImageNetPolicy
from megatron import get_args
from PIL import Image, ImageOps
class VitSegmentationJointTransform():
def __init__(self, train=True, resolution=None):
self.train = train
if self.train:
self.transform0 = ET.RandomSizeAndCrop(resolution)
self.transform1 = ET.RandomHorizontallyFlip()
def __call__(self, img, mask):
if self.train:
img, mask = self.transform0(img, mask)
img, mask = self.transform1(img, mask)
return img, mask
class VitSegmentationImageTransform():
def __init__(self, train=True, resolution=None):
args = get_args()
self.train = train
assert args.fp16 or args.bf16
self.data_type = torch.half if args.fp16 else torch.bfloat16
self.mean_std = args.mean_std
if self.train:
assert resolution is not None
self.transform = T.Compose([
ET.PhotoMetricDistortion(),
T.ToTensor(),
T.Normalize(*self.mean_std),
T.ConvertImageDtype(self.data_type)
])
else:
self.transform = T.Compose([
T.ToTensor(),
T.Normalize(*self.mean_std),
T.ConvertImageDtype(self.data_type)
])
def __call__(self, input):
output = self.transform(input)
return output
class VitSegmentationTargetTransform():
def __init__(self, train=True, resolution=None):
self.train = train
def __call__(self, input):
output = torch.from_numpy(np.array(input, dtype=np.int32)).long()
return output
class RandomSeedSegmentationDataset(Dataset):
def __init__(self,
dataset,
joint_transform,
image_transform,
target_transform):
args = get_args()
self.base_seed = args.seed
self.curr_seed = self.base_seed
self.dataset = dataset
self.joint_transform = joint_transform
self.image_transform = image_transform
self.target_transform = target_transform
def __len__(self):
return len(self.dataset)
def set_epoch(self, epoch):
self.curr_seed = self.base_seed + 100 * epoch
def __getitem__(self, idx):
seed = idx + self.curr_seed
img, mask = self.dataset[idx]
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
img, mask = self.joint_transform(img, mask)
img = self.image_transform(img)
mask = self.target_transform(mask)
return img, mask
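# Determinism note: the sample at index idx in epoch e is augmented with
# seed base_seed + 100 * e + idx, so the random joint transforms are
# reproducible per sample while still varying across epochs.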
def build_cityscapes_train_valid_datasets(data_path, image_size):
args = get_args()
args.num_classes = Cityscapes.num_classes
args.ignore_index = Cityscapes.ignore_index
args.color_table = Cityscapes.color_table
args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_joint_transform = \
VitSegmentationJointTransform(train=True, resolution=image_size)
val_joint_transform = \
VitSegmentationJointTransform(train=False, resolution=image_size)
train_image_transform = \
VitSegmentationImageTransform(train=True, resolution=image_size)
val_image_transform = \
VitSegmentationImageTransform(train=False, resolution=image_size)
train_target_transform = \
VitSegmentationTargetTransform(train=True, resolution=image_size)
val_target_transform = \
VitSegmentationTargetTransform(train=False, resolution=image_size)
# training dataset
train_data = Cityscapes(
root=data_path[0],
split='train',
mode='fine',
resolution=image_size
)
train_data = RandomSeedSegmentationDataset(
train_data,
joint_transform=train_joint_transform,
image_transform=train_image_transform,
target_transform=train_target_transform)
# validation dataset
val_data = Cityscapes(
root=data_path[0],
split='val',
mode='fine',
resolution=image_size
)
val_data = RandomSeedSegmentationDataset(
val_data,
joint_transform=val_joint_transform,
image_transform=val_image_transform,
target_transform=val_target_transform)
return train_data, val_data
def build_train_valid_datasets(data_path, image_size):
return build_cityscapes_train_valid_datasets(data_path, image_size)
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
import numpy as np
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SegformerSegmentationModel
from megatron.model.vision.utils import resize
def calculate_iou(hist_data):
acc = np.diag(hist_data).sum() / hist_data.sum()
acc_cls = np.diag(hist_data) / hist_data.sum(axis=1)
acc_cls = np.nanmean(acc_cls)
divisor = hist_data.sum(axis=1) + hist_data.sum(axis=0) - \
np.diag(hist_data)
iu = np.diag(hist_data) / divisor
return iu, acc, acc_cls
def fast_hist(pred, gtruth, num_classes):
# mask indicates pixels we care about
mask = (gtruth >= 0) & (gtruth < num_classes)
# stretch ground truth labels by num_classes
# class 0 -> 0
# class 1 -> 19
# class 18 -> 342
#
# TP at 0 + 0, 1 + 1, 2 + 2 ...
#
# TP exist where value == num_classes*class_id + class_id
# FP = row[class].sum() - TP
# FN = col[class].sum() - TP
hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
minlength=num_classes ** 2)
hist = hist.reshape(num_classes, num_classes)
return hist
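# Worked example with num_classes = 3, gtruth = [0, 1, 2], pred = [0, 2, 2]:
# indices = 3 * gtruth + pred = [0, 5, 8], so
#   hist = [[1, 0, 0],
#           [0, 0, 1],
#           [0, 0, 1]]
# The diagonal holds per-class TP; feeding this hist to calculate_iou
# gives iu = [1.0, 0.0, 0.5] and overall acc = 2/3.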
def segmentation():
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
return train_ds, valid_ds
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
model = SegformerSegmentationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
print_rank_0("model = {}".format(model))
return model
def process_batch(batch):
"""Process batch and produce inputs for the model."""
images = batch[0].cuda().contiguous()
masks = batch[1].cuda().contiguous()
return images, masks
def calculate_weight(masks, num_classes):
bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
hist_norm = bins.float()/bins.sum()
hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
return hist
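# Worked example: two classes with pixel counts [90, 10] give
# hist_norm = [0.9, 0.1] and weights [1.1, 1.9], a mild
# inverse-frequency up-weighting of the rarer class.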
def cross_entropy_loss_func(images, masks, output_tensor,
non_loss_data=False):
args = get_args()
ignore_index = args.ignore_index
color_table = args.color_table
logits = output_tensor.contiguous().float()
logits = resize(logits, size=masks.shape[1:],
mode='bilinear', align_corners=False)
# Cross-entropy loss.
# weight = calculate_weight(masks, num_classes)
loss = F.cross_entropy(logits, masks, ignore_index=ignore_index)
if not non_loss_data:
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
else:
seg_mask = logits.argmax(dim=1)
output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
return torch.cat((images, output_mask, gt_mask), dim=2), loss
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers("batch generator").start()
import types
if isinstance(batch, types.GeneratorType):
batch_ = next(batch)
else:
batch_ = batch
images, masks = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
output_tensor = model(images)
return output_tensor, partial(cross_entropy_loss_func, images, masks)
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
forward_backward_func = get_forward_backward_func()
for m in model:
m.eval()
def loss_func(labels, output_tensor):
args = get_args()
logits = output_tensor
logits = resize(logits, size=labels.shape[1:],
mode='bilinear', align_corners=False)
loss_dict = {}
# Compute the correct answers.
probs = logits.contiguous().float().softmax(dim=1)
max_probs, preds = torch.max(probs, 1)
preds = preds.cpu().numpy()
performs = fast_hist(preds.flatten(),
labels.cpu().numpy().flatten(),
args.ignore_index)
loss_dict['performs'] = performs
return 0, loss_dict
# defined inside to capture output_predictions
def correct_answers_forward_step(batch, model):
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
# Forward model.
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
with torch.no_grad():
# For all the batches in the dataset.
performs = None
for _, batch in enumerate(dataloader):
loss_dicts = forward_backward_func(correct_answers_forward_step,
batch, model,
optimizer=None,
timers=None,
forward_only=True)
for loss_dict in loss_dicts:
if performs is None:
performs = loss_dict['performs']
else:
performs += loss_dict['performs']
for m in model:
m.train()
# Reduce.
if mpu.is_pipeline_last_stage():
performs_tensor = torch.cuda.FloatTensor(performs)
torch.distributed.all_reduce(performs_tensor,
group=mpu.get_data_parallel_group())
hist = performs_tensor.cpu().numpy()
iu, acc, acc_cls = calculate_iou(hist)
miou = np.nanmean(iu)
return iu, miou
def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
dataloader = build_data_loader(
valid_ds,
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
shuffle=False
)
def metrics_func(model, epoch):
print_rank_0("calculating metrics ...")
iou, miou = calculate_correct_answers(model, dataloader, epoch)
print_rank_last(
" >> |epoch: {}| overall: iou = {},"
"miou = {:.4f} %".format(epoch, iou, miou*100.0)
)
return metrics_func
def dump_output_data(data, iteration, writer):
for (output_tb, loss) in data:
# output_tb[output_tb < 0] = 0
# output_tb[output_tb > 1] = 1
writer.add_images("image-outputseg-realseg", output_tb,
global_step=None, walltime=None,
dataformats='NCHW')
"""Finetune/evaluate."""
finetune(
train_valid_datasets_provider,
model_provider,
forward_step=_cross_entropy_forward_step,
process_non_loss_data_func=dump_output_data,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
segmentation()
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.metrics import CFMatrix
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SetrSegmentationModel
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins
def segmentation():
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
return train_ds, valid_ds
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
return SetrSegmentationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
def process_batch(batch):
"""Process batch and produce inputs for the model."""
images = batch[0].cuda().contiguous()
masks = batch[1].cuda().contiguous()
return images, masks
def calculate_weight(masks, num_classes):
bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
hist_norm = bins.float()/bins.sum()
hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
return hist
def cross_entropy_loss_func(images, masks, output_tensor, non_loss_data=False):
args = get_args()
ignore_index = args.ignore_index
color_table = args.color_table
weight = calculate_weight(masks, args.num_classes)
logits = output_tensor.contiguous().float()
loss = F.cross_entropy(logits, masks, weight=weight, ignore_index=ignore_index)
if not non_loss_data:
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
else:
seg_mask = logits.argmax(dim=1)
output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
return torch.cat((images, output_mask, gt_mask), dim=2), loss
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
args = get_args()
timers = get_timers()
# Get the batch.
timers("batch generator").start()
import types
if isinstance(batch, types.GeneratorType):
batch_ = next(batch)
else:
batch_ = batch
images, masks = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
if not model.training:
images, masks, _, _ = slidingcrops(images, masks)
#print_rank_0("images size = {}".format(images.size()))
if not model.training:
output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)])
else:
output_tensor = model(images)
return output_tensor, partial(cross_entropy_loss_func, images, masks)
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
forward_backward_func = get_forward_backward_func()
for m in model:
m.eval()
def loss_func(labels, slices_info, img_size, output_tensor):
args = get_args()
logits = output_tensor
loss_dict = {}
# Compute the correct answers.
probs = logits.contiguous().float().softmax(dim=1)
max_probs, preds = torch.max(probs, 1)
preds = preds.int()
preds, labels = slidingjoins(preds, max_probs, labels, slices_info, img_size)
_, performs = CFMatrix()(preds, labels, args.ignore_index)
loss_dict['performs'] = performs
return 0, loss_dict
# defined inside to capture output_predictions
def correct_answers_forward_step(batch, model):
args = get_args()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
assert not model.training
images, labels, slices_info, img_size = slidingcrops(images, labels)
# Forward model.
output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)])
return output_tensor, partial(loss_func, labels, slices_info, img_size)
with torch.no_grad():
# For all the batches in the dataset.
performs = None
for _, batch in enumerate(dataloader):
loss_dicts = forward_backward_func(correct_answers_forward_step,
batch, model,
optimizer=None,
timers=None,
forward_only=True)
for loss_dict in loss_dicts:
if performs is None:
performs = loss_dict['performs']
else:
performs += loss_dict['performs']
for m in model:
m.train()
# Reduce.
if mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(performs,
group=mpu.get_data_parallel_group())
# Print on screen.
# performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
true_positive = performs[:, 0]
false_positive = performs[:, 1]
false_negative = performs[:, 3]
iou = true_positive / (true_positive + false_positive + false_negative)
miou = iou[~torch.isnan(iou)].mean()
return iou.tolist(), miou.item()
def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
dataloader = build_data_loader(
valid_ds,
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
shuffle=False
)
def metrics_func(model, epoch):
print_rank_0("calculating metrics ...")
iou, miou = calculate_correct_answers(model, dataloader, epoch)
print_rank_last(
" >> |epoch: {}| overall: iou = {},"
"miou = {:.4f} %".format(epoch, iou, miou*100.0)
)
return metrics_func
def dump_output_data(data, iteration, writer):
for (output_tb, loss) in data:
# output_tb[output_tb < 0] = 0
# output_tb[output_tb > 1] = 1
writer.add_images("image-outputseg-realseg", output_tb,
global_step=None, walltime=None,
dataformats='NCHW')
"""Finetune/evaluate."""
finetune(
train_valid_datasets_provider,
model_provider,
forward_step=_cross_entropy_forward_step,
process_non_loss_data_func=dump_output_data,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
segmentation()
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
#copyright (c) go-hiroaki & Chokurei
#email: guangmingwu2010@gmail.com
# guozhilingty@gmail.com
#
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
eps = 1e-6  # small constant to avoid division by zero in the metrics below
def _binarize(y_data, threshold):
"""
args:
y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
threshold : [float] [0.0, 1.0]
return 4-d binarized y_data
"""
y_data[y_data < threshold] = 0.0
y_data[y_data >= threshold] = 1.0
return y_data
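# e.g. threshold = 0.5 maps [0.2, 0.7] to [0.0, 1.0] (modified in place).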
def _argmax(y_data, dim):
"""
args:
y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
dim : int
return 3-d [int] y_data
"""
return torch.argmax(y_data, dim).int()
def _get_tp(y_pred, y_true):
"""
args:
y_true : [int] 3-d in [batch_size, img_rows, img_cols]
y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
return [float] true_positive
"""
return torch.sum(y_true * y_pred).float()
def _get_fp(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_positive
"""
return torch.sum((1 - y_true) * y_pred).float()
def _get_tn(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] true_negative
"""
return torch.sum((1 - y_true) * (1 - y_pred)).float()
def _get_fn(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_negative
"""
return torch.sum(y_true * (1 - y_pred)).float()
def _get_weights(y_true, nb_ch):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
nb_ch : int
return [float] weights
"""
batch_size, img_rows, img_cols = y_true.shape
pixels = batch_size * img_rows * img_cols
weights = [torch.sum(y_true==ch).item() / pixels for ch in range(nb_ch)]
return weights
class CFMatrix(object):
def __init__(self, des=None):
self.des = des
def __repr__(self):
return "ConfusionMatrix"
def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
threshold : [0.0, 1.0]
return confusion matrix
"""
batch_size, img_rows, img_cols = y_pred.shape
chs = ignore_index  # ignore_index == num_classes, so it also serves as the class count
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fp = _get_fp(y_pred, y_true)
nb_tn = _get_tn(y_pred, y_true)
nb_fn = _get_fn(y_pred, y_true)
mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
performs = None
else:
performs = torch.zeros(chs, 4).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
y_false_ch = torch.zeros(batch_size, img_rows, img_cols)
y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
y_true_ch[y_true == ch] = 1
y_false_ch[torch.logical_and((y_true != ch), (y_true != ignore_index))] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
nb_fn = _get_fn(y_pred_ch, y_true_ch)
performs[int(ch), :] = torch.FloatTensor([nb_tp, nb_fp, nb_tn, nb_fn])
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
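# Each row of performs is [tp, fp, tn, fn] for one class; mperforms is
# the average of those rows weighted by per-class pixel frequency.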
class OAAcc(object):
def __init__(self, des="Overall Accuracy"):
self.des = des
def __repr__(self):
return "OAcc"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return (tp+tn)/total
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
nb_tp_tn = torch.sum(y_true == y_pred).float()
mperforms = nb_tp_tn / (batch_size * img_rows * img_cols)
performs = None
return mperforms, performs
class Precision(object):
def __init__(self, des="Precision"):
self.des = des
def __repr__(self):
return "Prec"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return tp/(tp+fp)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fp = _get_fp(y_pred, y_true)
mperforms = nb_tp / (nb_tp + nb_fp + eps)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fp = _get_fp(y_pred_ch, y_true_ch)
performs[int(ch)] = nb_tp / (nb_tp + nb_fp + eps)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class Recall(object):
def __init__(self, des="Recall"):
self.des = des
def __repr__(self):
return "Reca"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return tp/(tp+fn)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fn = _get_fn(y_pred, y_true)
mperforms = nb_tp / (nb_tp + nb_fn + eps)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fn = _get_fn(y_pred_ch, y_true_ch)
performs[int(ch)] = nb_tp / (nb_tp + nb_fn + eps)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class F1Score(object):
def __init__(self, des="F1Score"):
self.des = des
def __repr__(self):
return "F1Sc"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return 2*precision*recall/(precision+recall)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fp = _get_fp(y_pred, y_true)
nb_fn = _get_fn(y_pred, y_true)
_precision = nb_tp / (nb_tp + nb_fp + esp)
_recall = nb_tp / (nb_tp + nb_fn + esp)
mperforms = 2 * _precision * _recall / (_precision + _recall + esp)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fp = _get_fp(y_pred_ch, y_true_ch)
nb_fn = _get_fn(y_pred_ch, y_true_ch)
_precision = nb_tp / (nb_tp + nb_fp + esp)
_recall = nb_tp / (nb_tp + nb_fn + esp)
performs[int(ch)] = 2 * _precision * \
_recall / (_precision + _recall + esp)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class Kappa(object):
def __init__(self, des="Kappa"):
self.des = des
def __repr__(self):
return "Kapp"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return (Po-Pe)/(1-Pe)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fp = _get_fp(y_pred, y_true)
nb_tn = _get_tn(y_pred, y_true)
nb_fn = _get_fn(y_pred, y_true)
nb_total = nb_tp + nb_fp + nb_tn + nb_fn
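                # Po: observed agreement; Pe: agreement expected by chance from the confusion-matrix marginals.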
Po = (nb_tp + nb_tn) / nb_total
Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
(nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
mperforms = (Po - Pe) / (1 - Pe + esp)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fp = _get_fp(y_pred_ch, y_true_ch)
nb_tn = _get_tn(y_pred_ch, y_true_ch)
nb_fn = _get_fn(y_pred_ch, y_true_ch)
nb_total = nb_tp + nb_fp + nb_tn + nb_fn
Po = (nb_tp + nb_tn) / nb_total
Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn)
+ (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
performs[int(ch)] = (Po - Pe) / (1 - Pe + esp)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class Jaccard(object):
def __init__(self, des="Jaccard"):
self.des = des
def __repr__(self):
return "Jacc"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return intersection / (sum-intersection)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
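            # IoU: intersection over union, where union = |A| + |B| - intersection.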
_intersec = torch.sum(y_true * y_pred).float()
_sum = torch.sum(y_true + y_pred).float()
mperforms = _intersec / (_sum - _intersec + esp)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
_intersec = torch.sum(y_true_ch * y_pred_ch).float()
_sum = torch.sum(y_true_ch + y_pred_ch).float()
performs[int(ch)] = _intersec / (_sum - _intersec + esp)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class MSE(object):
def __init__(self, des="Mean Square Error"):
self.des = des
def __repr__(self):
return "MSE"
def __call__(self, y_pred, y_true, dim=1, threshold=None):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
threshold : [0.0, 1.0]
return mean_squared_error, smaller the better
"""
if threshold:
y_pred = _binarize(y_pred, threshold)
return torch.mean((y_pred - y_true) ** 2)
class PSNR(object):
def __init__(self, des="Peak Signal to Noise Ratio"):
self.des = des
def __repr__(self):
return "PSNR"
def __call__(self, y_pred, y_true, dim=1, threshold=None):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
threshold : [0.0, 1.0]
return PSNR, larger the better
"""
if threshold:
y_pred = _binarize(y_pred, threshold)
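        # PSNR = 10 * log10(MAX**2 / MSE); inputs are assumed in [0, 1], so MAX = 1.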
mse = torch.mean((y_pred - y_true) ** 2)
return 10 * torch.log10(1 / mse)
class SSIM(object):
'''
modified from https://github.com/jorge-pessoa/pytorch-msssim
'''
def __init__(self, des="structural similarity index"):
self.des = des
def __repr__(self):
return "SSIM"
def gaussian(self, w_size, sigma):
gauss = torch.Tensor([math.exp(-(x - w_size//2)**2/float(2*sigma**2)) for x in range(w_size)])
return gauss/gauss.sum()
def create_window(self, w_size, channel=1):
_1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
return window
def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
w_size : int, default 11
size_average : boolean, default True
full : boolean, default False
return ssim, larger the better
"""
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if torch.max(y_pred) > 128:
max_val = 255
else:
max_val = 1
if torch.min(y_pred) < -0.5:
min_val = -1
else:
min_val = 0
L = max_val - min_val
padd = 0
(_, channel, height, width) = y_pred.size()
window = self.create_window(w_size, channel=channel).to(y_pred.device)
mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(y_pred * y_pred, window, padding=padd, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(y_true * y_true, window, padding=padd, groups=channel) - mu2_sq
sigma12 = F.conv2d(y_pred * y_true, window, padding=padd, groups=channel) - mu1_mu2
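        # Stability constants: C1 = (K1*L)**2, C2 = (K2*L)**2 with K1 = 0.01, K2 = 0.03
        # (the SSIM paper defaults), where L is the dynamic range of the inputs.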
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = torch.mean(v1 / v2) # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
if size_average:
ret = ssim_map.mean()
else:
ret = ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
class AE(object):
"""
Modified from matlab : colorangle.m, MATLAB V2019b
angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
angle = 180 / pi * angle;
"""
def __init__(self, des='average Angular Error'):
self.des = des
def __repr__(self):
return "AE"
def __call__(self, y_pred, y_true):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
return average AE, smaller the better
"""
dotP = torch.sum(y_pred * y_true, dim=1)
Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
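        # Per-pixel angle in degrees between predicted and true color vectors; the epsilon guards acos against division by zero.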
        ae = 180 / math.pi * torch.acos(dotP / (Norm_pred * Norm_true + esp))
return ae.mean(1).mean(1)
if __name__ == "__main__":
for ch in [3, 1]:
batch_size, img_row, img_col = 1, 224, 224
y_true = torch.rand(batch_size, ch, img_row, img_col)
noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
y_pred = y_true + noise
for cuda in [False, True]:
if cuda:
y_pred = y_pred.cuda()
y_true = y_true.cuda()
print('#'*20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
########### similarity metrics
metric = MSE()
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
metric = PSNR()
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
metric = SSIM()
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
metric = LPIPS(cuda)
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
metric = AE()
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
########### accuracy metrics
metric = OAAcc()
maccu, accu = metric(y_pred, y_true)
print('mAccu:', maccu, 'Accu', accu)
metric = Precision()
mprec, prec = metric(y_pred, y_true)
print('mPrec:', mprec, 'Prec', prec)
metric = Recall()
mreca, reca = metric(y_pred, y_true)
print('mReca:', mreca, 'Reca', reca)
metric = F1Score()
mf1sc, f1sc = metric(y_pred, y_true)
print('mF1sc:', mf1sc, 'F1sc', f1sc)
metric = Kappa()
mkapp, kapp = metric(y_pred, y_true)
print('mKapp:', mkapp, 'Kapp', kapp)
metric = Jaccard()
mjacc, jacc = metric(y_pred, y_true)
print('mJacc:', mjacc, 'Jacc', jacc)
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.module import MegatronModule
from megatron.model.vision.utils import resize
class SetrSegmentationHead(MegatronModule):
def __init__(self, hidden_size, num_classes):
super(SetrSegmentationHead, self).__init__()
args = get_args()
self.hidden_size = hidden_size
self.num_classes = num_classes
self.img_h = args.img_h
self.img_w = args.img_w
self.patch_dim = args.patch_dim
self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
1, 1, bias=False)
self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)
def to_2D(self, x):
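        # Reshape transformer output [n, h*w, c] into a [n, c, h, w] feature map.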
n, hw, c = x.shape
h = self.img_h // self.patch_dim
w = self.img_w // self.patch_dim
assert(hw == h * w)
x = x.transpose(1, 2).reshape(n, c, h, w)
return x
def forward(self, hidden_states):
# [b c h w]
hidden_states = self.layernorm(hidden_states)
hidden_states = self.to_2D(hidden_states)
hidden_states = self.conv_0(hidden_states)
hidden_states = self.norm_0(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.conv_1(hidden_states)
# [b c h w]
result = F.interpolate(hidden_states,
size=(self.img_h, self.img_w),
mode='bilinear')
return result
class MLP(torch.nn.Module):
"""
Linear Embedding
"""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = torch.nn.Linear(input_dim, embed_dim)
def forward(self, x):
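        # Flatten [n, c, h, w] to [n, h*w, c] and project each token to embed_dim.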
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class SegformerSegmentationHead(MegatronModule):
def __init__(self, feature_strides, in_channels,
embedding_dim, dropout_ratio):
super(SegformerSegmentationHead, self).__init__()
assert len(feature_strides) == len(in_channels)
assert min(feature_strides) == feature_strides[0]
args = get_args()
self.feature_strides = feature_strides
self.in_channels = in_channels
self.embedding_dim = embedding_dim
self.num_classes = args.num_classes
self.dropout_ratio = dropout_ratio
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
self.in_channels
self.linear_c4 = MLP(input_dim=c4_in_channels,
embed_dim=self.embedding_dim)
self.linear_c3 = MLP(input_dim=c3_in_channels,
embed_dim=self.embedding_dim)
self.linear_c2 = MLP(input_dim=c2_in_channels,
embed_dim=self.embedding_dim)
self.linear_c1 = MLP(input_dim=c1_in_channels,
embed_dim=self.embedding_dim)
self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4,
self.embedding_dim, 1, 1)
self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
self.num_classes,
kernel_size=1)
def forward(self, inputs):
c1, c2, c3, c4 = inputs
############## MLP decoder on C1-C4 ###########
n, _, h, w = c4.shape
_c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
_c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
x = self.norm(_c)
x = F.relu(x, inplace=True)
x = self.dropout(x)
x = self.linear_pred(x)
return x
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.module import MegatronModule
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, SegformerSegmentationHead
class SetrSegmentationModel(MegatronModule):
def __init__(self,
num_classes,
pre_process=True,
post_process=True):
super(SetrSegmentationModel, self).__init__()
args = get_args()
assert post_process & pre_process
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.backbone = VitBackbone(
pre_process=pre_process,
post_process=post_process,
class_token=False,
post_layer_norm=False,
drop_path_rate=0.1
)
self.head = SetrSegmentationHead(
self.hidden_size,
self.num_classes
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
# [b hw c]
hidden_states = self.backbone(input)
result_final = self.head(hidden_states)
return result_final
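# Usage sketch (assumes megatron global args are initialized; the num_classes value is illustrative):
#   model = SetrSegmentationModel(num_classes=21)
#   logits = model(images)  # images: [b, c, h, w] -> logits: [b, num_classes, args.img_h, args.img_w]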
class SegformerSegmentationModel(MegatronModule):
def __init__(self,
num_classes,
pre_process=True,
post_process=True):
super(SegformerSegmentationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
self.backbone = mit_b5()
self.head = SegformerSegmentationHead(
feature_strides=[4, 8, 16, 32],
in_channels=[64, 128, 320, 512],
embedding_dim=768,
dropout_ratio=0.1
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
# [b hw c]
hidden_states = self.backbone(input)
hidden_states = self.head(hidden_states)
return hidden_states
# Copyright (c) 2020 The MMSegmentation Authors.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron import print_rank_0
from megatron import get_args
from PIL import Image, ImageOps, ImageEnhance
import torchvision.transforms as torch_tr
def _is_pil_image(img):
return isinstance(img, Image.Image)
class PhotoMetricDistortion(object):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Args:
brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast.
saturation_range (tuple): range of saturation.
hue_delta (int): delta of hue.
"""
def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def convert(self, img, alpha=1, beta=0):
"""Multiple with alpha and add beat with clip."""
img = img.astype(np.float32) * alpha + beta
img = np.clip(img, 0, 255)
return img.astype(np.uint8)
def brightness(self, img):
"""Brightness distortion."""
if random.randint(0, 1):
return self.convert(
img,
beta=random.uniform(-self.brightness_delta,
self.brightness_delta))
return img
def contrast(self, img):
"""Contrast distortion."""
if random.randint(0, 1):
return self.convert(
img,
alpha=random.uniform(self.contrast_lower, self.contrast_upper))
return img
def saturation(self, img):
"""Saturation distortion."""
if random.randint(0, 1):
img = mmcv.bgr2hsv(img)
img[:, :, 1] = self.convert(
img[:, :, 1],
alpha=random.uniform(self.saturation_lower,
self.saturation_upper))
img = mmcv.hsv2bgr(img)
return img
def hue(self, img):
"""Hue distortion."""
if random.randint(0, 1):
img = mmcv.bgr2hsv(img)
img[:, :,
0] = (img[:, :, 0].astype(int) +
random.randint(-self.hue_delta, self.hue_delta)) % 180
img = mmcv.hsv2bgr(img)
return img
def __call__(self, img):
"""Call function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
img = np.array(img)
# random brightness
img = self.brightness(img)
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = random.randint(0, 1)
if mode == 1:
img = self.contrast(img)
# random saturation
img = self.saturation(img)
# random hue
img = self.hue(img)
# random contrast
if mode == 0:
img = self.contrast(img)
img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
return img
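# Minimal usage sketch (hypothetical: `img` is assumed to be a PIL RGB image):
#   distort = PhotoMetricDistortion(brightness_delta=32,
#                                   contrast_range=(0.5, 1.5),
#                                   saturation_range=(0.5, 1.5),
#                                   hue_delta=18)
#   img = distort(img)  # returns a distorted PIL RGB image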
class RandomCrop(object):
"""
Take a random crop from the image.
First the image or crop size may need to be adjusted if the incoming image
is too small...
If the image is smaller than the crop, then:
the image is padded up to the size of the crop
unless 'nopad', in which case the crop size is shrunk to fit the image
A random crop is taken such that the crop fits within the image.
    if cfg.DATASET.TRANSLATION_AUG_FIX is set, we ensure that there's always
    translation randomness of at least that value around the image.
if image < crop_size:
# slide crop within image, random offset
else:
# slide image within crop
"""
def __init__(self, crop_size):
args = get_args()
self.size = crop_size
self.cat_max_ratio = 0.75
self.ignore_index = args.ignore_index
self.pad_color = (0, 0, 0)
def get_crop_bbox(self, img):
"""Randomly get a crop bounding box."""
img_w, img_h = img.size
target_h, target_w = self.size #[H W]
margin_h = max(img_h - target_h, 0)
margin_w = max(img_w - target_w, 0)
offset_h = random.randint(0, margin_h)
offset_w = random.randint(0, margin_w)
crop_y1, crop_y2 = offset_h, offset_h + target_h
crop_x1, crop_x2 = offset_w, offset_w + target_w
return crop_y1, crop_y2, crop_x1, crop_x2
def crop(self, img, crop_bbox):
"""Crop from ``img``"""
crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
return img
@staticmethod
def crop_in_image(target_w, target_h, w, h, img, mask):
if w == target_w:
x1 = 0
else:
x1 = random.randint(0, w - target_w)
if h == target_h:
y1 = 0
else:
y1 = random.randint(0, h - target_h)
return [img.crop((x1, y1, x1 + target_w, y1 + target_h)),
mask.crop((x1, y1, x1 + target_w, y1 + target_h))]
def __call__(self, img, mask):
w, h = img.size
target_h, target_w = self.size # ASSUME H, W
if w == target_w and h == target_h:
return img, mask
# Pad image if image < crop
if target_h > h:
pad_h = (target_h - h) // 2 + 1
else:
pad_h = 0
if target_w > w:
pad_w = (target_w - w) // 2 + 1
else:
pad_w = 0
border = (pad_w, pad_h, pad_w, pad_h)
if pad_h or pad_w:
img = ImageOps.expand(img, border=border, fill=(0, 0, 0))
mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)
w, h = img.size
crop_bbox = self.get_crop_bbox(img)
if self.cat_max_ratio < 1.:
# Repeat 10 times
for _ in range(10):
seg_temp = self.crop(mask, crop_bbox)
labels, cnt = np.unique(seg_temp, return_counts=True)
cnt = cnt[labels != self.ignore_index]
if len(cnt) > 1 and np.max(cnt) / np.sum(
cnt) < self.cat_max_ratio:
break
crop_bbox = self.get_crop_bbox(img)
# crop the image
img = self.crop(img, crop_bbox)
# crop semantic seg
mask = self.crop(mask, crop_bbox)
assert(img.size[0] == self.size[1] and img.size[1] == self.size[0])
return img, mask
class RandomSizeAndCrop(object):
def __init__(self,
crop_size,
scale_min=0.5,
scale_max=2.0):
self.crop = RandomCrop(crop_size)
self.scale_min = scale_min
self.scale_max = scale_max
def __call__(self, img, mask):
scale_amt = random.uniform(self.scale_min, self.scale_max)
w, h = [int(i * scale_amt) for i in img.size]
resized_img = img.resize((w, h), Image.BICUBIC)
resized_mask = mask.resize((w, h), Image.NEAREST)
img, mask = self.crop(resized_img, resized_mask)
return img, mask
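# Usage sketch (hypothetical crop size; RandomCrop reads ignore_index from the megatron args):
#   aug = RandomSizeAndCrop(crop_size=(768, 768), scale_min=0.5, scale_max=2.0)
#   img, mask = aug(img, mask)  # random rescale in [0.5, 2.0], then a 768x768 crop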
class RandomHorizontallyFlip(object):
def __call__(self, img, mask):
if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(
Image.FLIP_LEFT_RIGHT)
return img, mask
def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an Image.
Args:
img (PIL Image): PIL Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
Returns:
PIL Image: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image.
Args:
img (PIL Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
PIL Image: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image.
Args:
img (PIL Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
PIL Image: Saturation adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
def adjust_hue(img, hue_factor):
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
See https://en.wikipedia.org/wiki/Hue for more details on Hue.
Args:
img (PIL Image): PIL Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
Returns:
PIL Image: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img
h, s, v = img.convert('HSV').split()
np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.
Args:
brightness (float): How much to jitter brightness. brightness_factor
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
contrast (float): How much to jitter contrast. contrast_factor
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
saturation (float): How much to jitter saturation. saturation_factor
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
[-hue, hue]. Should be >=0 and <= 0.5.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
@staticmethod
def get_params(brightness, contrast, saturation, hue):
"""Get a randomized transform to be applied on image.
Arguments are same as that of __init__.
Returns:
Transform which randomly adjusts brightness, contrast and
saturation in a random order.
"""
transforms = []
if brightness > 0:
brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
transforms.append(
torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
if contrast > 0:
contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
transforms.append(
torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
if saturation > 0:
saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
transforms.append(
torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
if hue > 0:
hue_factor = np.random.uniform(-hue, hue)
transforms.append(
torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
np.random.shuffle(transforms)
transform = torch_tr.Compose(transforms)
return transform
def __call__(self, img):
"""
Args:
img (PIL Image): Input image.
Returns:
PIL Image: Color jittered image.
"""
transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)
return transform(img)
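# Usage sketch (hypothetical jitter strengths):
#   jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
#   img = jitter(img)  # applies the enabled adjustments in a random order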
import math
import torch
import numpy as np
from megatron import get_args
def slidingcrops(img, mask):
# img: [b c h w]
# mask: [b h w]
args = get_args()
assert args.img_h == args.img_w
crop_size = args.img_h
stride = args.seg_stride
ignore_index = args.ignore_index
n, c, h, w = img.shape
assert h >= crop_size
assert w >= crop_size
long_size = max(h, w)
img_slices, mask_slices, slices_info = [], [], []
if long_size > crop_size:
assert stride <= crop_size
h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1
for yy in range(h_step_num):
for xx in range(w_step_num):
sy, sx = yy * stride, xx * stride
ey, ex = sy + crop_size, sx + crop_size
img_sub = img[:, :, sy: ey, sx: ex]
mask_sub = mask[:, sy: ey, sx: ex]
# padding
sub_h, sub_w = img_sub.shape[2:]
pad_h = max(crop_size - sub_h, 0)
pad_w = max(crop_size - sub_w, 0)
img_sub = torch.nn.functional.pad(img_sub, pad=(0, pad_w, 0, pad_h), value=ignore_index)
mask_sub = torch.nn.functional.pad(mask_sub, pad=(0, pad_w, 0, pad_h))
img_slices.append(img_sub)
mask_slices.append(mask_sub)
slices_info.append([sy, ey, sx, ex, sub_h, sub_w])
return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
else:
return img, mask, [[0, h, 0, w, h, w]], (h, w)
def slidingjoins(preds, probs, labels, slices_info, img_size):
args = get_args()
num_slices = len(slices_info)
if num_slices == 1:
return preds, labels
h, w = img_size
split_size = args.micro_batch_size
preds_split = torch.split(preds, split_size)
probs_split = torch.split(probs, split_size)
labels_split = torch.split(labels, split_size)
assert(len(preds_split) == num_slices)
total_max_probs = torch.zeros((split_size, h, w), dtype=torch.float, device='cuda')
total_preds = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
total_labels = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
for i in range(num_slices):
sy, ey, sx, ex, sub_h, sub_w = slices_info[i]
assert sy + sub_h <= h
assert sx + sub_w <= w
curr_max_probs = total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w]
curr_preds = total_preds[:, sy:sy + sub_h, sx:sx + sub_w]
local_max_probs = probs_split[i][:, :sub_h, : sub_w]
local_preds = preds_split[i][:, :sub_h, :sub_w]
result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
result_preds = torch.where(curr_max_probs >= local_max_probs, curr_preds, local_preds)
total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w] = result_max_probs
total_preds[:, sy:sy + sub_h, sx:sx + sub_w] = result_preds
total_labels[:, sy:sy + sub_h, sx:sx + sub_w] = labels_split[i][0, :sub_h, :sub_w]
return total_preds, total_labels
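# Sketch of how the two helpers above compose for sliding-window inference
# (hypothetical `model` callable; assumes megatron args are initialized):
#   img_crops, mask_crops, info, size = slidingcrops(img, mask)
#   probs, preds = model(img_crops)  # hypothetical: per-crop max-probabilities and argmax predictions
#   preds, labels = slidingjoins(preds, probs, mask_crops, info, size)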
import json
import os
import sys
import types
import torch
def add_arguments(parser):
group = parser.add_argument_group(title='Megatron loader')
group.add_argument('--true-vocab-size', type=int, default=None,
help='original size of vocab, if specified will trim padding from embedding table.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file. If specified will use this to get vocab size and '
'trim padding from the embedding table.')
group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')
def _load_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.arguments import parse_args, validate_args
from megatron.global_vars import set_args, set_global_variables
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
from megatron.model import ModelType, module
from megatron import mpu, fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
queue.put("exit")
exit(1)
# We want all arguments to come from us
sys.argv = ['script.py',
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--load', args.load_dir
]
margs = parse_args()
margs = load_args_from_checkpoint(margs)
    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking there are plenty of processes.
margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
margs = validate_args(margs)
def check_for_arg(arg_name):
if getattr(margs, arg_name, None) is None:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {margs}")
queue.put("exit")
exit(1)
check_for_arg('tensor_model_parallel_size')
check_for_arg('pipeline_model_parallel_size')
check_for_arg('num_layers')
check_for_arg('hidden_size')
check_for_arg('seq_length')
check_for_arg('num_attention_heads')
check_for_arg('max_position_embeddings')
check_for_arg('tokenizer_type')
check_for_arg('iteration')
check_for_arg('bert_binary_head')
check_for_arg('params_dtype')
# Determine how to make our models
if args.model_type == 'GPT':
from pretrain_gpt import model_provider
margs.model_type = ModelType.encoder_or_decoder
elif args.model_type == 'BERT':
from pretrain_bert import model_provider
margs.model_type = ModelType.encoder_or_decoder
else:
raise Exception(f'unrecognized model type: {args.model_type}')
    # Suppress warning about torch.distributed not being initialized
module.MegatronModule.embedding_warning_printed = True
consumed_train_samples = None
consumed_valid_samples = None
def get_models(count, dtype, pre_process, post_process):
nonlocal consumed_train_samples
nonlocal consumed_valid_samples
models = []
for rank in range(count):
mpu.initialize.set_tensor_model_parallel_rank(rank)
model_ = [model_provider(pre_process, post_process).to(dtype)]
margs.consumed_train_samples = 0
margs.consumed_valid_samples = 0
load_checkpoint(model_, None, None)
assert(len(model_) == 1)
model_ = model_[0]
if consumed_train_samples is not None:
assert(margs.consumed_train_samples == consumed_train_samples)
else:
consumed_train_samples = margs.consumed_train_samples
if consumed_valid_samples is not None:
assert(margs.consumed_valid_samples == consumed_valid_samples)
else:
consumed_valid_samples = margs.consumed_valid_samples
models.append(model_)
return models
if margs.num_layers_per_virtual_pipeline_stage is not None:
print("Model with an interleaved pipeline schedule are not yet supported.")
queue.put("exit")
exit(1)
set_global_variables(margs)
mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
fused_kernels.load(margs)
# Get true (non-padded) vocab size
if args.true_vocab_size is not None:
true_vocab_size = args.true_vocab_size
elif args.vocab_file is not None:
vocab = json.load(open(args.vocab_file))
true_vocab_size = len(vocab)
if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
queue.put("exit")
exit(1)
else:
true_vocab_size = None
# short aliases
tp_size = margs.tensor_model_parallel_size
pp_size = margs.pipeline_model_parallel_size
# metadata
md = types.SimpleNamespace()
md.model_type = args.model_type
md.num_layers = margs.num_layers
md.hidden_size = margs.hidden_size
md.seq_length = margs.seq_length
md.num_attention_heads = margs.num_attention_heads
md.max_position_embeddings = margs.max_position_embeddings
md.tokenizer_type = margs.tokenizer_type
md.iteration = margs.iteration
md.params_dtype = margs.params_dtype
md.bert_binary_head = margs.bert_binary_head
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
md.true_vocab_size = true_vocab_size
md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
# Get first pipe stage
mpu.initialize.set_pipeline_model_parallel_rank(0)
post_process = pp_size == 1
models = get_models(tp_size, md.params_dtype, True, post_process)
md.consumed_train_samples = consumed_train_samples
md.consumed_valid_samples = consumed_valid_samples
queue.put(md)
def queue_put(name, msg):
print(f"sending {name}")
msg["name"] = name
queue.put(msg)
# Send embeddings
message = {
"position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
"word embeddings": torch.cat(
[models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
dim = 0)
}
queue_put("embeddings", message)
total_layer_num = 0
for pp_rank in range(pp_size):
if pp_rank > 0:
mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == pp_size - 1
models = get_models(tp_size, md.params_dtype, False, post_process)
for layer_num in range(len(models[0].language_model.encoder.layers)):
message = {}
# Get non-parallel tensors from tp_rank 0
layer = models[0].language_model.encoder.layers[layer_num]
message["input layernorm weight"] = layer.input_layernorm.weight.data
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["dense bias"] = layer.self_attention.dense.bias.data
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
# Grab all parallel tensors for this layer
qkv_weight = []
qkv_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
for tp_rank, model in enumerate(models):
layer = model.language_model.encoder.layers[layer_num]
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
# concat them
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
queue_put(f"transformer layer {total_layer_num}", message)
total_layer_num = total_layer_num + 1
# Send final layernorm from tp_rank 0
message = {
"weight": models[0].language_model.encoder.final_layernorm.weight.data,
"bias": models[0].language_model.encoder.final_layernorm.bias.data
}
queue_put("final layernorm", message)
# Send BERT lm head and binary head if it exists
if md.model_type == 'BERT':
print("Sending LM Pooler")
message = {
"weight": models[0].language_model.pooler.dense.weight.data,
"bias": models[0].language_model.pooler.dense.bias.data
}
queue_put("pooler", message)
message = {
"dense weight": models[0].lm_head.dense.weight.data,
"dense bias": models[0].lm_head.dense.bias.data,
"layernorm weight": models[0].lm_head.layernorm.weight.data,
"layernorm bias": models[0].lm_head.layernorm.bias.data
}
queue_put("lm head", message)
if md.bert_binary_head:
print("Sending BERT Binary head")
queue.put("binary head")
message = {
"weight": models[0].binary_head.weight.data,
"bias": models[0].binary_head.bias.data
}
queue_put("binary head", message)
queue.put("done")
def load_checkpoint(queue, args):
try:
_load_checkpoint(queue, args)
except:
queue.put("exit")
raise
import argparse
from collections.abc import Mapping
import concurrent.futures
import os
import sys
import torch
def add_arguments(parser):
group = parser.add_argument_group(title='Megatron saver')
group.add_argument('--megatron-path', type=str, default=None,
help='Base directory of Megatron repository')
group.add_argument('--target-tensor-parallel-size', type=int,
help='Target tensor model parallel size, defaults to the tensor parallel size '
'in the input checkpoint if provided by the loader, otherwise to 1')
group.add_argument('--target-pipeline-parallel-size', type=int,
                       help='Target pipeline model parallel size, defaults to the pipeline parallel size '
'in the input checkpoint if provided by the loader, otherwise to 1')
def save_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import save_checkpoint
from megatron.global_vars import set_global_variables, get_args
from megatron.model import ModelType
from megatron.tokenizer.tokenizer import _vocab_size_with_padding
from megatron import mpu, fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
exit(1)
def queue_get(name=None):
val = queue.get()
if val == "exit":
print("Loader exited, exiting saver")
exit(1)
if name is not None and args.checking and val["name"] != name:
val_name = val["name"]
print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
exit(1)
if name is not None:
print(f"received {name}")
return val
def check_message(msg):
if not args.checking:
return
msg_name = msg.pop("name")
if len(msg.keys()) > 0:
print(f"Unexpected values in {msg_name}:")
for key in msg.keys():
print(f" {key}")
print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
exit(1)
md = queue_get()
if args.target_tensor_parallel_size is None:
if hasattr(md, 'previous_tensor_parallel_size'):
args.target_tensor_parallel_size = md.previous_tensor_parallel_size
else:
print("loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
"Default to 1.")
args.target_tensor_parallel_size = 1
if args.target_pipeline_parallel_size is None:
if hasattr(md, 'previous_pipeline_parallel_size'):
args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
else:
print("loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
"Default to 1.")
args.target_pipeline_parallel_size = 1
    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking there are plenty of processes.
if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'
# We want all arguments to come from us
sys.argv = ['script.py',
'--num-layers', str(md.num_layers),
'--hidden-size', str(md.hidden_size),
'--seq-length', str(md.seq_length),
'--num-attention-heads', str(md.num_attention_heads),
'--max-position-embeddings', str(md.max_position_embeddings),
'--tokenizer-type', str(md.tokenizer_type),
'--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
'--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--save-interval', '1',
'--save', args.save_dir
]
if md.make_vocab_size_divisible_by is not None:
sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
if md.params_dtype == torch.float16:
sys.argv.append('--fp16')
elif md.params_dtype == torch.bfloat16:
sys.argv.append('--bf16')
if md.model_type == 'BERT' and not md.bert_binary_head:
sys.argv.append('--bert-no-binary-head')
margs = parse_args()
validate_args(margs)
set_global_variables(margs)
# margs = megatron args
margs = get_args()
if hasattr(md, 'consumed_train_samples'):
margs.consumed_train_samples = md.consumed_train_samples
margs.consumed_valid_samples = md.consumed_valid_samples
print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
f" and consumed_valid_samples to {margs.consumed_valid_samples}")
else:
print("consumed_train_samples not provided.")
# Determine how to make our models
if md.model_type == 'GPT':
from pretrain_gpt import model_provider
margs.model_type = ModelType.encoder_or_decoder
elif md.model_type == 'BERT':
from pretrain_bert import model_provider
margs.model_type = ModelType.encoder_or_decoder
else:
        raise Exception(f'unrecognized model type: {md.model_type}')
def get_models(count, dtype, pre_process, post_process):
models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
return models
# fake initializing distributed
mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_rank(0)
fused_kernels.load(margs)
# Embeddings
#-----------
embeddings_msg = queue_get("embeddings")
pos_embed = embeddings_msg.pop("position embeddings")
orig_word_embed = embeddings_msg.pop("word embeddings")
check_message(embeddings_msg)
# Deal with padding
if md.true_vocab_size is not None:
# figure out what our padded vocab size is
orig_vocab_size = orig_word_embed.shape[0]
margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
# Cut out extra padding we don't need
if orig_vocab_size > margs.padded_vocab_size:
full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]
# Expanding embedding to larger size by replicating final entry
elif orig_vocab_size < margs.padded_vocab_size:
padding_size = margs.padded_vocab_size - orig_vocab_size
full_word_embed = torch.cat((
orig_word_embed,
orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
# Same size!
else:
full_word_embed = orig_word_embed
else:
print("Original vocab size not specified, leaving embedding table as-is. "
"If you've changed the tensor parallel size this could cause problems.")
margs.padded_vocab_size = orig_word_embed.shape[0]
full_word_embed = orig_word_embed
# Split into new tensor model parallel sizes
out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
# Make models for first pipeline stage and fill in embeddings
mpu.initialize.set_pipeline_model_parallel_rank(0)
post_process = args.target_pipeline_parallel_size == 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
for tp_rank, model in enumerate(models):
print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
# Transformer layers
#-------------------
total_layer_num = 0
for pp_rank in range(args.target_pipeline_parallel_size):
# For later pipeline parallel ranks, make the new models
if pp_rank > 0:
mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == args.target_pipeline_parallel_size - 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
for layer in range(len(models[0].language_model.encoder.layers)):
msg = queue_get(f"transformer layer {total_layer_num}")
# duplicated tensors
input_layernorm_weight = msg.pop("input layernorm weight")
input_layernorm_bias = msg.pop("input layernorm bias")
dense_bias = msg.pop("dense bias")
post_layernorm_weight = msg.pop("post layernorm weight")
post_layernorm_bias = msg.pop("post layernorm bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
# Split up the parallel tensors
qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
# Save them to the model
for tp_rank in range(args.target_tensor_parallel_size):
l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
total_layer_num = total_layer_num + 1
check_message(msg)
if post_process:
msg = queue_get("final layernorm")
final_layernorm_weight = msg.pop("weight")
final_layernorm_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if pp_rank != 0:
# Copy word embeddings to final pipeline rank
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
del final_layernorm_weight
del final_layernorm_bias
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "pooler":
if not hasattr(models[0].language_model, 'pooler'):
print("ERROR: got a pooler, but model does not have one")
exit(1)
print("received pooler")
pooler_weight = msg.pop("weight")
pooler_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
del pooler_weight
del pooler_bias
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "lm head":
if not hasattr(models[0], 'lm_head'):
print("ERROR: got an lm head, but model does not have one")
exit(1)
print("received lm head")
lm_head_dense_weight = msg.pop("dense weight")
lm_head_dense_bias = msg.pop("dense bias")
lm_head_layernorm_weight = msg.pop("layernorm weight")
lm_head_layernorm_bias = msg.pop("layernorm bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "binary head":
if not hasattr(models[0], 'binary_head'):
print("ERROR: got a binary head, but model does not have one")
exit(1)
print("received binary head")
binary_head_weight = msg.pop("weight")
binary_head_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
check_message(msg)
msg = queue_get()
if msg != "done":
print("ERROR: got some more data but was expecting to be done")
for tp_rank in range(args.target_tensor_parallel_size):
mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
save_checkpoint(md.iteration, [models[tp_rank]], None, None)
print("Done!")
import argparse
import importlib
import torch.multiprocessing as mp
import os
import sys
# A loader is a python file with at least two functions
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments
# A saver is similar but has save_checkpoint instead of
# load_checkpoint
# The loader and saver processes are each given a queue. The loader
# should load the checkpoint and send the weights as messages in the
# following order; the saver should receive them in this order and
# save the checkpoint. A message consists of a python dictionary with
# a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full model weights, never split.
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
# - Metadata Namespace with the following attributes:
# model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
# num_layers - Number of transformer layers
# hidden_size
# seq_length
# num_attention_heads
# max_position_embeddings
# tokenizer_type
# iteration
# params_dtype
# bert_binary_head - Used only if model_type is BERT
# previous_tensor_parallel_size - Optional
# previous_pipeline_parallel_size - Optional
# true_vocab_size
#   make_vocab_size_divisible_by
# consumed_train_samples
# consumed_valid_samples
# messages
# {
# "name": "embeddings"
# "position embeddings"
# "word embeddings"
# }
# (for each transformer layer):
# {
# "name": "transformer layer N"
# "input layernorm weight"
# "input layernorm bias"
# "qkv weight"
# "qkv bias"
# "dense weight"
# "dense bias"
# "post layernorm weight"
# "post layernorm bias"
# "mlp l0 weight"
# "mlp l0 bias"
# "mlp l1 weight"
# "mlp l1 bias"
# }
# {
#     "name": "final layernorm"
# "weight"
# "bias"
# }
# if present (i.e. for BERT):
# {
# "name": "pooler"
# "weight"
# "bias"
# }
# {
# "name": "lm head"
# "dense weight"
# "dense bias"
# "layernorm weight"
# "layernorm bias"
# }
# {
# "name": "binary head"
# "weight"
# "bias"
# }
# - "done"
def load_plugin(plugin_type, name):
module_name = f"checkpoint_{plugin_type}_{name}"
try:
plugin = importlib.import_module(module_name)
except ModuleNotFoundError:
module_name = name
try:
plugin = importlib.import_module(module_name)
except ModuleNotFoundError:
sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")
if not hasattr(plugin, 'add_arguments'):
sys.exit(f"{module_name} module is not a plugin. Exiting.")
print(f"Loaded {module_name} as the {plugin_type}.")
return plugin
def main():
parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
allow_abbrev=False, conflict_handler='resolve')
parser.add_argument('--model-type', type=str, required=True,
choices=['GPT', 'BERT'],
help='Type of the model')
parser.add_argument('--loader', type=str, default='megatron',
help='Module name to load checkpoint, should be on python path')
parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
parser.add_argument('--load-dir', type=str, required=True,
help='Directory to load model checkpoint from')
parser.add_argument('--save-dir', type=str, required=True,
help='Directory to save model checkpoint to')
parser.add_argument('--max-queue-size', type=int, default=50,
help='Maximum number of tensors in the queue')
parser.add_argument('--no-checking', action='store_false',
help='Do not perform checking on the name and ordering of weights',
dest='checking')
known_args, _ = parser.parse_known_args()
loader = load_plugin('loader', known_args.loader)
saver = load_plugin('saver', known_args.saver)
loader.add_arguments(parser)
saver.add_arguments(parser)
args = parser.parse_args()
queue = mp.Queue(maxsize=args.max_queue_size)
print("Starting saver...")
saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
saver_proc.start()
print("Starting loader...")
loader.load_checkpoint(queue, args)
print("Waiting for saver to complete...")
saver_proc.join()
if __name__ == '__main__':
main()
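# Example invocation (hypothetical paths; assumes this file is saved as
# checkpoint_util.py with the loader/saver modules on the python path):
#   python checkpoint_util.py --model-type GPT \
#       --loader megatron --saver megatron \
#       --load-dir /path/to/input/checkpoint --save-dir /path/to/output/checkpoint \
#       --target-tensor-parallel-size 2 --target-pipeline-parallel-size 1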