# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import datetime
import json
import logging
import os
import random
import time
from datetime import timedelta
from pathlib import Path

import detr.util.misc as utils
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from detectron2.engine.launch import _find_free_port
from detectron2.utils.file_io import PathManager
from detr import datasets
from detr.datasets import build_dataset, get_coco_api_from_dataset
from detr.engine import evaluate, train_one_epoch
from detr.models import build_model
from torch.utils.data import DataLoader, DistributedSampler

DEFAULT_TIMEOUT = timedelta(minutes=30)


def get_args_parser():
    """Build the parser holding every DETR training / evaluation option.

    The parser is created with ``add_help=False`` so it can be used as a
    parent of another parser (``invoke_main`` does exactly that).

    Returns:
        argparse.ArgumentParser: the populated parser.
    """
    parser = argparse.ArgumentParser("Set transformer detector", add_help=False)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_backbone", default=1e-5, type=float)
    parser.add_argument("--batch_size", default=2, type=int)
    parser.add_argument("--weight_decay", default=1e-4, type=float)
    parser.add_argument("--epochs", default=300, type=int)
    parser.add_argument("--lr_drop", default=200, type=int)
    parser.add_argument(
        "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm"
    )

    # Model parameters
    parser.add_argument(
        "--frozen_weights",
        type=str,
        default=None,
        help="Path to the pretrained model. If set, only the mask head will be trained",
    )
    # * Backbone
    parser.add_argument(
        "--backbone",
        default="resnet50",
        type=str,
        help="Name of the convolutional backbone to use",
    )
    parser.add_argument(
        "--dilation",
        action="store_true",
        help="If true, we replace stride with dilation in the last convolutional block (DC5)",
    )
    parser.add_argument(
        "--position_embedding",
        default="sine",
        type=str,
        choices=("sine", "learned"),
        help="Type of positional embedding to use on top of the image features",
    )

    # * Transformer
    parser.add_argument(
        "--enc_layers",
        default=6,
        type=int,
        help="Number of encoding layers in the transformer",
    )
    parser.add_argument(
        "--dec_layers",
        default=6,
        type=int,
        help="Number of decoding layers in the transformer",
    )
    parser.add_argument(
        "--dim_feedforward",
        default=2048,
        type=int,
        help="Intermediate size of the feedforward layers in the transformer blocks",
    )
    parser.add_argument(
        "--hidden_dim",
        default=256,
        type=int,
        help="Size of the embeddings (dimension of the transformer)",
    )
    parser.add_argument(
        "--dropout", default=0.1, type=float, help="Dropout applied in the transformer"
    )
    parser.add_argument(
        "--nheads",
        default=8,
        type=int,
        help="Number of attention heads inside the transformer's attentions",
    )
    parser.add_argument(
        "--num_queries", default=100, type=int, help="Number of query slots"
    )
    parser.add_argument("--pre_norm", action="store_true")

    # * Segmentation
    parser.add_argument(
        "--masks",
        action="store_true",
        help="Train segmentation head if the flag is provided",
    )

    # Loss
    parser.add_argument(
        "--no_aux_loss",
        dest="aux_loss",
        action="store_false",
        help="Disables auxiliary decoding losses (loss at each layer)",
    )
    # * Matcher
    parser.add_argument(
        "--set_cost_class",
        default=1,
        type=float,
        help="Class coefficient in the matching cost",
    )
    parser.add_argument(
        "--set_cost_bbox",
        default=5,
        type=float,
        help="L1 box coefficient in the matching cost",
    )
    parser.add_argument(
        "--set_cost_giou",
        default=2,
        type=float,
        help="giou box coefficient in the matching cost",
    )
    # * Loss coefficients
    parser.add_argument("--mask_loss_coef", default=1, type=float)
    parser.add_argument("--dice_loss_coef", default=1, type=float)
    parser.add_argument("--bbox_loss_coef", default=5, type=float)
    parser.add_argument("--giou_loss_coef", default=2, type=float)
    parser.add_argument(
        "--eos_coef",
        default=0.1,
        type=float,
        help="Relative classification weight of the no-object class",
    )

    # dataset parameters
    parser.add_argument("--dataset_file", default="coco")
    parser.add_argument(
        "--ade_path",
        type=str,
        default="manifold://winvision/tree/detectron2/ADEChallengeData2016/",
    )
    parser.add_argument(
        "--coco_path", type=str, default="manifold://fair_vision_data/tree/"
    )
    parser.add_argument(
        "--coco_panoptic_path", type=str, default="manifold://fair_vision_data/tree/"
    )
    parser.add_argument("--remove_difficult", action="store_true")

    parser.add_argument(
        "--output-dir", default="", help="path where to save, empty for no saving"
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument(
        "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
    )
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--num_workers", default=2, type=int)

    # distributed training parameters
    parser.add_argument(
        "--num-gpus", type=int, default=8, help="number of gpus *per machine*"
    )
    parser.add_argument(
        "--num-machines", type=int, default=1, help="total number of machines"
    )
    parser.add_argument(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    parser.add_argument(
        "--dist-url", default="env://", help="url used to set up distributed training"
    )
    return parser


def main(args):
    """Train and/or evaluate a DETR model in the current process.

    Expects ``args.distributed`` to have been set by ``launch``; in
    distributed runs ``args.gpu`` (this process' local GPU index) is set by
    ``_distributed_worker``. In single-process runs ``args.gpu`` is NOT set,
    so this function must not read it outside the distributed branch.

    Args:
        args: namespace produced by the parser from ``get_args_parser``.
    """
    # Process-group initialization is done by launch()/_distributed_worker,
    # not here.
    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility; offset by rank so ranks get different
    # random streams
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of params:", n_parameters)

    # Two parameter groups: backbone parameters use args.lr_backbone, all
    # other trainable parameters use the optimizer-level args.lr.
    param_dicts = [
        {
            "params": [
                p
                for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p
                for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(
        param_dicts, lr=args.lr, weight_decay=args.weight_decay
    )
    # Step decay every args.lr_drop epochs (StepLR's default gamma).
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set="train", args=args)
    dataset_val = build_dataset(image_set="val", args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True
    )

    data_loader_train = DataLoader(
        dataset_train,
        batch_sampler=batch_sampler_train,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )
    data_loader_val = DataLoader(
        dataset_val,
        args.batch_size,
        sampler=sampler_val,
        drop_last=False,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location="cpu")
        model_without_ddp.detr.load_state_dict(checkpoint["model"])

    if args.resume:
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location="cpu", check_hash=True
            )
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        # Restore optimizer/scheduler/epoch only when training and the
        # checkpoint contains them.
        if (
            not args.eval
            and "optimizer" in checkpoint
            and "lr_scheduler" in checkpoint
            and "epoch" in checkpoint
        ):
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1

    if args.eval:
        test_stats, coco_evaluator = evaluate(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
        )
        # Only the main process opens the file: opening with "wb" on every
        # rank would truncate the file the main process writes.
        if args.output_dir and utils.is_main_process():
            with PathManager.open(os.path.join(args.output_dir, "eval.pth"), "wb") as f:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, f)
        return

    print("Start training")
    start_time = time.time()
    # JSON log entries from every epoch of this run; the whole list is
    # rewritten to log.txt each epoch (the file is opened with "w").
    log_lines = []
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
        )
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = []
            # extra checkpoint before LR drop and every 10 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0:
                checkpoint_paths.append(
                    os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth")
                )
            # Gate the open itself on the main process: opening with "wb" on
            # every rank truncates the checkpoint, and args.gpu does not
            # exist in single-process runs.
            if utils.is_main_process():
                for checkpoint_path in checkpoint_paths:
                    with PathManager.open(checkpoint_path, "wb") as f:
                        utils.save_on_master(
                            {
                                "model": model_without_ddp.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "lr_scheduler": lr_scheduler.state_dict(),
                                "epoch": epoch,
                                "args": args,
                            },
                            f,
                        )

        test_stats, coco_evaluator = evaluate(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
        )

        log_stats = {
            **{f"train_{k}": v for k, v in train_stats.items()},
            **{f"test_{k}": v for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }

        if args.output_dir and utils.is_main_process():
            log_lines.append(json.dumps(log_stats) + "\n")
            # Mode "w" truncates, so write every accumulated entry rather
            # than only the current epoch's.
            with PathManager.open(os.path.join(args.output_dir, "log.txt"), "w") as f:
                f.write("".join(log_lines))

            # for evaluation logs
            if coco_evaluator is not None:
                PathManager.mkdirs(os.path.join(args.output_dir, "eval"))
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ["latest.pth"]
                    if epoch % 50 == 0:
                        filenames.append(f"{epoch:03}.pth")
                    for name in filenames:
                        with PathManager.open(
                            os.path.join(args.output_dir, "eval", name), "wb"
                        ) as f:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval, f)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-gpu or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
                       e.g. "tcp://127.0.0.1:8686".
                       Can be set to "auto" to automatically select a free port on localhost
        args (tuple): arguments passed to main_func. ``args[0]`` must be an
                      attribute-assignable object (the parsed argparse
                      namespace): ``args[0].distributed`` is set here.
        timeout (timedelta): timeout of the distributed workers
    """
    world_size = num_machines * num_gpus_per_machine
    args[0].distributed = world_size > 1
    if args[0].distributed:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes

        if dist_url == "auto":
            assert (
                num_machines == 1
            ), "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        # dist_url defaults to None, so guard before calling str methods.
        if (
            num_machines > 1
            and dist_url is not None
            and dist_url.startswith("file://")
        ):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )
        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                args,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*args)


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.

    No-op when torch.distributed is unavailable, not initialized, or the
    world size is 1.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    """Per-process entry point spawned by ``launch`` via ``mp.spawn``.

    Joins the NCCL process group at global rank
    ``machine_rank * num_gpus_per_machine + local_rank``, binds this process
    to CUDA device ``local_rank``, records it as ``args[0].gpu`` and then
    runs ``main_func(*args)``.
    """
    assert (
        torch.cuda.is_available()
    ), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception:
        # Log the URL for debugging, then re-raise with the original traceback.
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise
    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    synchronize()

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    # main() reads args.gpu for the DistributedDataParallel device id.
    args[0].gpu = local_rank

    # NOTE: unlike detectron2's launcher, no per-machine (local) process
    # group is created here.

    main_func(*args)


def invoke_main() -> None:
    """Parse the command line, create the output dir if given, and launch."""
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    if args.output_dir:
        PathManager.mkdirs(args.output_dir)
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )


if __name__ == "__main__":
    invoke_main()  # pragma: no cover