Commit 84b97851 authored by chenych

First commit

# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import argparse
import datetime
import json
import random
import time
import os
import torch
import datasets
import numpy as np
import util.misc as utils
import datasets.samplers as samplers
from pathlib import Path
from torch.utils.data import DataLoader
from datasets import build_dataset, get_coco_api_from_dataset
from engine import evaluate, train_one_epoch
from models import build_model
def get_args_parser():
parser = argparse.ArgumentParser("Deformable DETR Detector", add_help=False)
parser.add_argument("--lr", default=2e-4, type=float)
parser.add_argument(
"--lr_backbone_names", default=["backbone.0"], type=str, nargs="+"
)
parser.add_argument("--lr_backbone", default=2e-5, type=float)
parser.add_argument(
"--lr_linear_proj_names",
default=["reference_points", "sampling_offsets"],
type=str,
nargs="+",
)
parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float)
parser.add_argument("--batch_size", default=4, type=int)
parser.add_argument("--weight_decay", default=1e-4, type=float)
parser.add_argument("--epochs", default=50, type=int)
parser.add_argument("--lr_drop", default=40, type=int)
parser.add_argument("--lr_drop_epochs", default=None, type=int, nargs="+")
parser.add_argument(
"--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm"
)
parser.add_argument("--sgd", action="store_true")
# Variants of Deformable DETR
parser.add_argument("--with_box_refine", default=False, action="store_true")
parser.add_argument("--two_stage", default=False, action="store_true")
# Model parameters
parser.add_argument(
"--frozen_weights",
type=str,
default=None,
help="Path to the pretrained model. If set, only the mask head will be trained",
)
# * Backbone
parser.add_argument(
"--backbone",
default="resnet50",
type=str,
help="Name of the convolutional backbone to use",
)
parser.add_argument(
"--dilation",
action="store_true",
help="If true, we replace stride with dilation in the last convolutional block (DC5)",
)
parser.add_argument(
"--position_embedding",
default="sine",
type=str,
choices=("sine", "learned"),
help="Type of positional embedding to use on top of the image features",
)
parser.add_argument(
"--position_embedding_scale",
default=2 * np.pi,
type=float,
help="position / size * scale",
)
parser.add_argument(
"--num_feature_levels", default=4, type=int, help="number of feature levels"
)
# swin backbone
parser.add_argument(
"--pretrained_backbone_path",
default="./swin_tiny_patch4_window7_224.pkl",
type=str,
)
parser.add_argument("--drop_path_rate", default=0.2, type=float)
# * Transformer
parser.add_argument(
"--enc_layers",
default=6,
type=int,
help="Number of encoding layers in the transformer",
)
parser.add_argument(
"--dec_layers",
default=6,
type=int,
help="Number of decoding layers in the transformer",
)
parser.add_argument(
"--dim_feedforward",
default=2048,
type=int,
help="Intermediate size of the feedforward layers in the transformer blocks",
)
parser.add_argument(
"--hidden_dim",
default=256,
type=int,
help="Size of the embeddings (dimension of the transformer)",
)
parser.add_argument(
"--dropout", default=0.1, type=float, help="Dropout applied in the transformer"
)
parser.add_argument(
"--nheads",
default=8,
type=int,
help="Number of attention heads inside the transformer's attentions",
)
parser.add_argument(
"--num_queries_one2one",
default=300,
type=int,
help="Number of query slots for one-to-one matching",
)
parser.add_argument(
"--num_queries_one2many",
default=0,
type=int,
help="Number of query slots for one-to-many matchining",
)
parser.add_argument("--dec_n_points", default=4, type=int)
parser.add_argument("--enc_n_points", default=4, type=int)
# Deformable DETR tricks
parser.add_argument("--mixed_selection", action="store_true", default=False)
parser.add_argument("--look_forward_twice", action="store_true", default=False)
# hybrid branch
parser.add_argument("--k_one2many", default=5, type=int)
parser.add_argument("--lambda_one2many", default=1.0, type=float)
# * Segmentation
parser.add_argument(
"--masks",
action="store_true",
help="Train segmentation head if the flag is provided",
)
# Loss
parser.add_argument(
"--no_aux_loss",
dest="aux_loss",
action="store_false",
help="Disables auxiliary decoding losses (loss at each layer)",
)
# * Matcher
parser.add_argument(
"--set_cost_class",
default=2,
type=float,
help="Class coefficient in the matching cost",
)
parser.add_argument(
"--set_cost_bbox",
default=5,
type=float,
help="L1 box coefficient in the matching cost",
)
parser.add_argument(
"--set_cost_giou",
default=2,
type=float,
help="giou box coefficient in the matching cost",
)
# * Loss coefficients
parser.add_argument("--mask_loss_coef", default=1, type=float)
parser.add_argument("--dice_loss_coef", default=1, type=float)
parser.add_argument("--cls_loss_coef", default=2, type=float)
parser.add_argument("--bbox_loss_coef", default=5, type=float)
parser.add_argument("--giou_loss_coef", default=2, type=float)
parser.add_argument("--focal_alpha", default=0.25, type=float)
# dataset parameters
parser.add_argument("--dataset_file", default="coco")
parser.add_argument("--coco_path", default="/public/DL_DATA/COCO2017", type=str)
parser.add_argument("--coco_panoptic_path", type=str)
parser.add_argument("--remove_difficult", action="store_true")
parser.add_argument(
"--output_dir", default="", help="path where to save, empty for no saving"
)
parser.add_argument(
"--device", default="cuda", help="device to use for training / testing"
)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--resume", default="", help="resume from checkpoint")
parser.add_argument(
"--start_epoch", default=0, type=int, metavar="N", help="start epoch"
)
parser.add_argument("--num_workers", default=2, type=int)
parser.add_argument(
"--cache_mode",
default=False,
action="store_true",
help="whether to cache images on memory",
)
# * eval technologies
parser.add_argument("--eval", action="store_true")
# eval in training set
parser.add_argument("--eval_in_training_set", default=False, action="store_true")
# topk for eval
parser.add_argument("--topk", default=100, type=int)
# * training technologies
parser.add_argument("--use_fp16", default=False, action="store_true")
parser.add_argument("--use_checkpoint", default=False, action="store_true")
# * logging technologies
parser.add_argument("--use_wandb", action="store_true", default=False)
return parser
def main(args):
utils.init_distributed_mode(args)
print("git:\n {}\n".format(utils.get_sha()))
if args.frozen_weights is not None:
assert args.masks, "Frozen training is meant for segmentation only"
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model, criterion, postprocessors = build_model(args)
model.to(device)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of params:", n_parameters)
dataset_train = build_dataset(image_set="train", args=args)
if not args.eval_in_training_set:
dataset_val = build_dataset(
image_set="val", args=args, eval_in_training_set=False,
)
else:
print("eval in the training set")
dataset_val = build_dataset(
image_set="train", args=args, eval_in_training_set=True,
)
if args.distributed:
if args.cache_mode:
sampler_train = samplers.NodeDistributedSampler(dataset_train)
sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = samplers.DistributedSampler(dataset_train)
sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=True
)
data_loader_train = DataLoader(
dataset_train,
batch_sampler=batch_sampler_train,
collate_fn=utils.collate_fn,
num_workers=args.num_workers,
pin_memory=True,
)
data_loader_val = DataLoader(
dataset_val,
args.batch_size,
sampler=sampler_val,
drop_last=False,
collate_fn=utils.collate_fn,
num_workers=args.num_workers,
pin_memory=True,
)
# lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
def match_name_keywords(n, name_keywords):
out = False
for b in name_keywords:
if b in n:
out = True
break
return out
for n, p in model_without_ddp.named_parameters():
print(n)
param_dicts = [
{
"params": [
p
for n, p in model_without_ddp.named_parameters()
if not match_name_keywords(n, args.lr_backbone_names)
and not match_name_keywords(n, args.lr_linear_proj_names)
and p.requires_grad
],
"lr": args.lr,
},
{
"params": [
p
for n, p in model_without_ddp.named_parameters()
if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad
],
"lr": args.lr_backbone,
},
{
"params": [
p
for n, p in model_without_ddp.named_parameters()
if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad
],
"lr": args.lr * args.lr_linear_proj_mult,
},
]
if args.sgd:
optimizer = torch.optim.SGD(
param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay
)
else:
optimizer = torch.optim.AdamW(
param_dicts, lr=args.lr, weight_decay=args.weight_decay
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.dataset_file == "coco_panoptic":
# We also evaluate AP during panoptic training, on the original COCO dataset
coco_val = datasets.coco.build("val", args)
base_ds = get_coco_api_from_dataset(coco_val)
else:
base_ds = get_coco_api_from_dataset(dataset_val)
if args.frozen_weights is not None:
checkpoint = torch.load(args.frozen_weights, map_location="cpu")
model_without_ddp.detr.load_state_dict(checkpoint["model"])
output_dir = Path(args.output_dir)
if args.resume and os.path.exists(args.resume):
if args.resume.startswith("https"):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location="cpu", check_hash=True
)
else:
checkpoint = torch.load(args.resume, map_location="cpu")
missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
checkpoint["model"], strict=False
)
unexpected_keys = [
k
for k in unexpected_keys
if not (k.endswith("total_params") or k.endswith("total_ops"))
]
if len(missing_keys) > 0:
print("Missing Keys: {}".format(missing_keys))
if len(unexpected_keys) > 0:
print("Unexpected Keys: {}".format(unexpected_keys))
if (
not args.eval
and "optimizer" in checkpoint
and "lr_scheduler" in checkpoint
and "epoch" in checkpoint
):
import copy
p_groups = copy.deepcopy(optimizer.param_groups)
optimizer.load_state_dict(checkpoint["optimizer"])
for pg, pg_old in zip(optimizer.param_groups, p_groups):
pg["lr"] = pg_old["lr"]
pg["initial_lr"] = pg_old["initial_lr"]
print(optimizer.param_groups)
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
# todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
args.override_resumed_lr_drop = True
if args.override_resumed_lr_drop:
print(
"Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler."
)
lr_scheduler.step_size = args.lr_drop
lr_scheduler.base_lrs = list(
map(lambda group: group["initial_lr"], optimizer.param_groups)
)
lr_scheduler.step(lr_scheduler.last_epoch)
args.start_epoch = checkpoint["epoch"] + 1
# check the resumed model
if not args.eval:
test_stats, coco_evaluator = evaluate(
model,
criterion,
postprocessors,
data_loader_val,
base_ds,
device,
args.output_dir,
use_wandb=args.use_wandb,
)
if args.eval:
test_stats, coco_evaluator = evaluate(
model,
criterion,
postprocessors,
data_loader_val,
base_ds,
device,
args.output_dir,
use_wandb=args.use_wandb,
)
if args.output_dir:
utils.save_on_master(
coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth"
)
return
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
sampler_train.set_epoch(epoch)
train_stats = train_one_epoch(
model,
criterion,
data_loader_train,
optimizer,
device,
epoch,
args.clip_max_norm,
k_one2many=args.k_one2many,
lambda_one2many=args.lambda_one2many,
use_wandb=args.use_wandb,
use_fp16=args.use_fp16,
)
lr_scheduler.step()
if args.output_dir:
checkpoint_paths = [output_dir / "checkpoint.pth"]
# in addition to the rolling checkpoint, save a checkpoint tagged with the epoch number
checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth")
for checkpoint_path in checkpoint_paths:
utils.save_on_master(
{
"model": model_without_ddp.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
"args": args,
},
checkpoint_path,
)
test_stats, coco_evaluator = evaluate(
model,
criterion,
postprocessors,
data_loader_val,
base_ds,
device,
args.output_dir,
use_wandb=args.use_wandb,
)
log_stats = {
**{f"train_{k}": v for k, v in train_stats.items()},
**{f"test_{k}": v for k, v in test_stats.items()},
"epoch": epoch,
"n_parameters": n_parameters,
}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
# for evaluation logs
if coco_evaluator is not None:
(output_dir / "eval").mkdir(exist_ok=True)
if "bbox" in coco_evaluator.coco_eval:
filenames = ["latest.pth"]
if epoch % 50 == 0:
filenames.append(f"{epoch:03}.pth")
for name in filenames:
torch.save(
coco_evaluator.coco_eval["bbox"].eval,
output_dir / "eval" / name,
)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
"Deformable DETR training and evaluation script", parents=[get_args_parser()]
)
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
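# ----------------------------------------------------------------------------
# Illustration (added note, not part of the original script): H-DETR combines
# `num_queries_one2one` queries trained with the usual one-to-one Hungarian
# matching and `num_queries_one2many` extra queries trained with one-to-many
# matching, where each ground-truth box is repeated `k_one2many` times and the
# auxiliary branch loss is weighted by `lambda_one2many`. The helper below is
# only a hypothetical sketch of that target duplication; the actual logic lives
# in engine.train_one_epoch and the criterion.
def _repeat_targets_for_one2many_sketch(targets, k_one2many):
    """Repeat each ground-truth entry k_one2many times (illustration only)."""
    import copy

    repeated = []
    for t in targets:
        t_new = copy.deepcopy(t)
        t_new["boxes"] = t["boxes"].repeat(k_one2many, 1)
        t_new["labels"] = t["labels"].repeat(k_one2many)
        repeated.append(t_new)
    return repeated
# ----------------------------------------------------------------------------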
# -*- coding: utf-8 -*-
from .checkpoint import load_checkpoint
__all__ = ["load_checkpoint"]
# Copyright (c) Open-MMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F
import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info
ENV_MMCV_HOME = "MMCV_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
def _get_mmcv_home():
mmcv_home = os.path.expanduser(
os.getenv(
ENV_MMCV_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv"),
)
)
mkdir_or_exist(mmcv_home)
return mmcv_home
def load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=""):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict,
prefix,
local_metadata,
True,
all_missing_keys,
unexpected_keys,
err_msg,
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(module)
load = None # break load->load reference cycle
# ignore "num_batches_tracked" of BN layers
missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key]
if unexpected_keys:
err_msg.append(
"unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n'
)
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n'
)
rank, _ = get_dist_info()
if len(err_msg) > 0 and rank == 0:
err_msg.insert(0, "The model and loaded state dict do not match exactly\n")
err_msg = "\n".join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
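# Usage sketch (illustration only, not part of the original mmcv file): load a
# raw state_dict into a module and have key mismatches reported instead of
# raising. The module and weights below are placeholders.
def _demo_load_state_dict():
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    state = model.state_dict()  # stand-in for torch.load(<checkpoint path>)
    # strict=False: missing/unexpected keys are printed (or sent to `logger`)
    # by load_state_dict rather than raising a RuntimeError.
    load_state_dict(model, state, strict=False)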
def load_url_dist(url, model_dir=None):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
rank, world_size = get_dist_info()
rank = int(os.environ.get("LOCAL_RANK", rank))
if rank == 0:
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
return checkpoint
def load_pavimodel_dist(model_path, map_location=None):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
try:
from pavi import modelcloud
except ImportError:
raise ImportError("Please install pavi to load checkpoint from modelcloud.")
rank, world_size = get_dist_info()
rank = int(os.environ.get("LOCAL_RANK", rank))
if rank == 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
return checkpoint
def load_fileclient_dist(filename, backend, map_location):
"""In distributed setting, this function only download checkpoint at local
rank 0."""
rank, world_size = get_dist_info()
rank = int(os.environ.get("LOCAL_RANK", rank))
allowed_backends = ["ceph"]
if backend not in allowed_backends:
raise ValueError(f"Load from Backend {backend} is not supported.")
if rank == 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint
def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f"torchvision.models.{name}")
if hasattr(_zoo, "model_urls"):
_urls = getattr(_zoo, "model_urls")
model_urls.update(_urls)
return model_urls
def get_external_models():
mmcv_home = _get_mmcv_home()
default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json")
default_urls = load_file(default_json_path)
assert isinstance(default_urls, dict)
external_json_path = osp.join(mmcv_home, "open_mmlab.json")
if osp.exists(external_json_path):
external_urls = load_file(external_json_path)
assert isinstance(external_urls, dict)
default_urls.update(external_urls)
return default_urls
def get_mmcls_models():
mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json")
mmcls_urls = load_file(mmcls_json_path)
return mmcls_urls
def get_deprecated_model_names():
deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json")
deprecate_urls = load_file(deprecate_json_path)
assert isinstance(deprecate_urls, dict)
return deprecate_urls
def _process_mmcls_checkpoint(checkpoint):
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith("backbone."):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint
def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns:
dict | OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
"""
if filename.startswith("modelzoo://"):
warnings.warn(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead'
)
model_urls = get_torchvision_models()
model_name = filename[11:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith("torchvision://"):
model_urls = get_torchvision_models()
model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith("open-mmlab://"):
model_urls = get_external_models()
model_name = filename[13:]
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(
f"open-mmlab://{model_name} is deprecated in favor "
f"of open-mmlab://{deprecated_urls[model_name]}"
)
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(("http://", "https://")):
checkpoint = load_url_dist(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f"{filename} is not a checkpoint file")
checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith("mmcls://"):
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_url_dist(model_urls[model_name])
checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(("http://", "https://")):
checkpoint = load_url_dist(filename)
elif filename.startswith("pavi://"):
model_path = filename[7:]
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
elif filename.startswith("s3://"):
checkpoint = load_fileclient_dist(
filename, backend="ceph", map_location=map_location
)
else:
if not osp.isfile(filename):
raise IOError(f"{filename} is not a checkpoint file")
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
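# Usage sketch (illustration only): the URI schemes accepted by
# _load_checkpoint. All names and paths below are examples, not files shipped
# with this repository.
def _demo_checkpoint_sources():
    ckpt_tv = _load_checkpoint("torchvision://resnet50")           # torchvision model zoo
    ckpt_url = _load_checkpoint("https://example.com/model.pth")   # plain HTTP(S) URL
    ckpt_local = _load_checkpoint("/path/to/model.pth")            # local file
    return ckpt_tv, ckpt_url, ckpt_local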
def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location)
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(f"No state_dict found in checkpoint file {filename}")
# get state_dict from checkpoint
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
elif "model" in checkpoint:
state_dict = checkpoint["model"]
else:
state_dict = checkpoint
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith("module."):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# for MoBY, load model of online branch
if sorted(list(state_dict.keys()))[0].startswith("encoder"):
state_dict = {
k.replace("encoder.", ""): v
for k, v in state_dict.items()
if k.startswith("encoder.")
}
# reshape absolute position embedding
if state_dict.get("absolute_pos_embed") is not None:
absolute_pos_embed = state_dict["absolute_pos_embed"]
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = model.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
logger.warning("Error in loading absolute_pos_embed, pass")
else:
state_dict["absolute_pos_embed"] = absolute_pos_embed.view(
N2, H, W, C2
).permute(0, 3, 1, 2)
# interpolate position bias table if needed
relative_position_bias_table_keys = [
k for k in state_dict.keys() if "relative_position_bias_table" in k
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
table_current = model.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f"Error in loading {table_key}, pass")
else:
if L1 != L2:
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2),
mode="bicubic",
)
state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(
1, 0
)
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
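# Usage sketch (illustration only): how a Swin backbone in this repository
# would typically consume load_checkpoint, e.g. from
# SwinTransformer.init_weights. The checkpoint path is a placeholder matching
# the default of --pretrained_backbone_path.
def _demo_load_swin_weights(swin_model):
    import logging

    logger = logging.getLogger("checkpoint_demo")
    # strict=False tolerates missing/extra keys; relative_position_bias_table
    # entries are resized above when the pretrained window size differs.
    return load_checkpoint(
        swin_model,
        "./swin_tiny_patch4_window7_224.pkl",  # placeholder path
        map_location="cpu",
        strict=False,
        logger=logger,
    )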
def weights_to_cpu(state_dict):
"""Copy a model state_dict to cpu.
Args:
state_dict (OrderedDict): Model weights on GPU.
Returns:
OrderedDict: Model weights on CPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
return state_dict_cpu
def _save_to_state_dict(module, destination, prefix, keep_vars):
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
Args:
module (nn.Module): The module to generate state_dict.
destination (dict): A dict where state will be stored.
prefix (str): The prefix for parameters and buffers used in this
module.
"""
for name, param in module._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in module._buffers.items():
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.detach()
def get_state_dict(module, destination=None, prefix="", keep_vars=False):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
This method is modified from :meth:`torch.nn.Module.state_dict` to
recursively check parallel module in case that the model has a complicated
structure, e.g., nn.Module(nn.Module(DDP)).
Args:
module (nn.Module): The module to generate state_dict.
destination (OrderedDict): Returned dict for the state of the
module.
prefix (str): Prefix of the key.
keep_vars (bool): Whether to keep the variable property of the
parameters. Default: False.
Returns:
dict: A dictionary containing a whole state of the module.
"""
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module
# below is the same as torch.nn.Module.state_dict()
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars)
for name, child in module._modules.items():
if child is not None:
get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars)
for hook in module._state_dict_hooks.values():
hook_result = hook(module, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
def save_checkpoint(model, filename, optimizer=None, meta=None):
"""Save checkpoint to file.
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
``optimizer``. By default ``meta`` will contain version and time info.
Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(f"meta must be a dict or None, but got {type(meta)}")
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
if is_module_wrapper(model):
model = model.module
if hasattr(model, "CLASSES") and model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=model.CLASSES)
checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint["optimizer"] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint["optimizer"] = {}
for name, optim in optimizer.items():
checkpoint["optimizer"][name] = optim.state_dict()
if filename.startswith("pavi://"):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
except ImportError:
raise ImportError("Please install pavi to load checkpoint from modelcloud.")
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, "wb") as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, "wb") as f:
torch.save(checkpoint, f)
f.flush()
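# Usage sketch (illustration only): saving a model plus its optimizer with
# custom metadata to a local path. All names below are placeholders.
def _demo_save_checkpoint(model, optimizer):
    save_checkpoint(
        model,
        "./work_dir/epoch_1.pth",
        optimizer=optimizer,
        meta=dict(epoch=1, note="demo"),
    )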
# Copyright (c) Open-MMLab. All rights reserved.
from .checkpoint import save_checkpoint
from .epoch_based_runner import EpochBasedRunnerAmp
__all__ = ["EpochBasedRunnerAmp", "save_checkpoint"]
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
from tempfile import TemporaryDirectory
import torch
from torch.optim import Optimizer
import mmcv
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
try:
import apex
except ImportError:
print("apex is not installed")
def save_checkpoint(model, filename, optimizer=None, meta=None):
"""Save checkpoint to file.
The checkpoint will have 4 fields: ``meta``, ``state_dict``,
``optimizer`` and ``amp``. By default ``meta`` will contain version
and time info.
Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(f"meta must be a dict or None, but got {type(meta)}")
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
if is_module_wrapper(model):
model = model.module
if hasattr(model, "CLASSES") and model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=model.CLASSES)
checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint["optimizer"] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint["optimizer"] = {}
for name, optim in optimizer.items():
checkpoint["optimizer"][name] = optim.state_dict()
# save amp state dict in the checkpoint
checkpoint["amp"] = apex.amp.state_dict()
if filename.startswith("pavi://"):
try:
from pavi import modelcloud
from pavi.exception import NodeNotFoundError
except ImportError:
raise ImportError("Please install pavi to load checkpoint from modelcloud.")
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, "wb") as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, "wb") as f:
torch.save(checkpoint, f)
f.flush()
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import torch
from torch.optim import Optimizer
import mmcv
from mmcv.runner import RUNNERS, EpochBasedRunner
from .checkpoint import save_checkpoint
try:
import apex
except ImportError:
print("apex is not installed")
@RUNNERS.register_module()
class EpochBasedRunnerAmp(EpochBasedRunner):
"""Epoch-based Runner with AMP support.
This runner trains models epoch by epoch.
"""
def save_checkpoint(
self,
out_dir,
filename_tmpl="epoch_{}.pth",
save_optimizer=True,
meta=None,
create_symlink=True,
):
"""Save the checkpoint.
Args:
out_dir (str): The directory that checkpoints are saved.
filename_tmpl (str, optional): The checkpoint filename template,
which contains a placeholder for the epoch number.
Defaults to 'epoch_{}.pth'.
save_optimizer (bool, optional): Whether to save the optimizer to
the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None.
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
"""
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
elif isinstance(meta, dict):
meta.update(epoch=self.epoch + 1, iter=self.iter)
else:
raise TypeError(f"meta should be a dict or None, but got {type(meta)}")
if self.meta is not None:
meta.update(self.meta)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
dst_file = osp.join(out_dir, "latest.pth")
if platform.system() != "Windows":
mmcv.symlink(filename, dst_file)
else:
shutil.copy(filepath, dst_file)
def resume(self, checkpoint, resume_optimizer=True, map_location="default"):
if map_location == "default":
if torch.cuda.is_available():
device_id = torch.cuda.current_device()
checkpoint = self.load_checkpoint(
checkpoint,
map_location=lambda storage, loc: storage.cuda(device_id),
)
else:
checkpoint = self.load_checkpoint(checkpoint)
else:
checkpoint = self.load_checkpoint(checkpoint, map_location=map_location)
self._epoch = checkpoint["meta"]["epoch"]
self._iter = checkpoint["meta"]["iter"]
if "optimizer" in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint["optimizer"])
elif isinstance(self.optimizer, dict):
for k in self.optimizer.keys():
self.optimizer[k].load_state_dict(checkpoint["optimizer"][k])
else:
raise TypeError(
"Optimizer should be dict or torch.optim.Optimizer "
f"but got {type(self.optimizer)}"
)
if "amp" in checkpoint:
apex.amp.load_state_dict(checkpoint["amp"])
self.logger.info("load amp state dict")
self.logger.info("resumed epoch %d, iter %d", self.epoch, self.iter)
# Unique model identifier
modelCode=475
# Model name
modelName=hdetr_pytorch
# Model description
modelDescription=H-DETR introduces a hybrid matching scheme: the new matching mechanism allows multiple queries to be assigned to each positive sample, which improves training and applies to a range of vision tasks such as object detection, 3D object detection, pose estimation, and object tracking.
# Application scenarios
appScenario=inference,training,object detection,cybersecurity,transportation,government
# Framework type
frameType=PyTorch
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
from .deformable_detr import build, build_test
def build_model(args):
return build(args)
def build_test_model(args):
return build_test(args)
# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
Backbone modules.
"""
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, is_main_process
from .position_encoding import build_position_encoding
from .swin_transformer import SwinTransformer
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with an added eps before rsqrt,
without which models other than torchvision.models.resnet[18,34,50,101]
produce NaNs.
"""
def __init__(self, n, eps=1e-5):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
self.eps = eps
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = self.eps
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
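# Illustration (not part of the original file): how a trained BatchNorm2d's
# buffers map onto FrozenBatchNorm2d. Hypothetical helper; the ResNet backbone
# below instead obtains frozen BN directly via torchvision's norm_layer hook.
def _freeze_batchnorm_sketch(bn: nn.BatchNorm2d) -> "FrozenBatchNorm2d":
    frozen = FrozenBatchNorm2d(bn.num_features, eps=bn.eps)
    with torch.no_grad():
        frozen.weight.copy_(bn.weight)
        frozen.bias.copy_(bn.bias)
        frozen.running_mean.copy_(bn.running_mean)
        frozen.running_var.copy_(bn.running_var)
    return frozen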
class BackboneBase(nn.Module):
def __init__(
self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool
):
super().__init__()
for name, parameter in backbone.named_parameters():
if (
not train_backbone
or "layer2" not in name
and "layer3" not in name
and "layer4" not in name
):
parameter.requires_grad_(False)
if return_interm_layers:
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
self.strides = [8, 16, 32]
self.num_channels = [512, 1024, 2048]
else:
return_layers = {"layer4": "0"}
self.strides = [32]
self.num_channels = [2048]
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(
self,
name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool,
):
norm_layer = FrozenBatchNorm2d
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=norm_layer,
)
assert name not in ("resnet18", "resnet34"), "number of channels are hard coded"
super().__init__(backbone, train_backbone, return_interm_layers)
if dilation:
self.strides[-1] = self.strides[-1] // 2
class TransformerBackbone(nn.Module):
def __init__(
self, backbone: str, train_backbone: bool, return_interm_layers: bool, args
):
super().__init__()
out_indices = (1, 2, 3)
if backbone == "swin_tiny":
backbone = SwinTransformer(
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
ape=False,
drop_path_rate=args.drop_path_rate,
patch_norm=True,
use_checkpoint=True,
out_indices=out_indices,
)
embed_dim = 96
backbone.init_weights(args.pretrained_backbone_path)
elif backbone == "swin_small":
backbone = SwinTransformer(
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
ape=False,
drop_path_rate=args.drop_path_rate,
patch_norm=True,
use_checkpoint=True,
out_indices=out_indices,
)
embed_dim = 96
backbone.init_weights(args.pretrained_backbone_path)
elif backbone == "swin_large":
backbone = SwinTransformer(
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=7,
ape=False,
drop_path_rate=args.drop_path_rate,
patch_norm=True,
use_checkpoint=True,
out_indices=out_indices,
)
embed_dim = 192
backbone.init_weights(args.pretrained_backbone_path)
elif backbone == "swin_large_window12":
backbone = SwinTransformer(
pretrain_img_size=384,
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
ape=False,
drop_path_rate=args.drop_path_rate,
patch_norm=True,
use_checkpoint=True,
out_indices=out_indices,
)
embed_dim = 192
backbone.init_weights(args.pretrained_backbone_path)
else:
raise NotImplementedError
for name, parameter in backbone.named_parameters():
# TODO: freeze some layers?
if not train_backbone:
parameter.requires_grad_(False)
if return_interm_layers:
self.strides = [8, 16, 32]
self.num_channels = [
embed_dim * 2,
embed_dim * 4,
embed_dim * 8,
]
else:
self.strides = [32]
self.num_channels = [embed_dim * 8]
self.body = backbone
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
self.strides = backbone.strides
self.num_channels = backbone.num_channels
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in sorted(xs.items()):
out.append(x)
# position encoding
for x in out:
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks or (args.num_feature_levels > 1)
if "resnet" in args.backbone:
backbone = Backbone(
args.backbone, train_backbone, return_interm_layers, args.dilation,
)
else:
backbone = TransformerBackbone(
args.backbone, train_backbone, return_interm_layers, args
)
model = Joiner(backbone, position_embedding)
return model
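# Usage sketch (illustration only): building the joined backbone + positional
# encoding from a minimal namespace. The field values mirror the defaults in
# main.get_args_parser(); the exact set of attributes consumed by
# build_position_encoding may differ slightly.
def _demo_build_backbone():
    from argparse import Namespace
    import math

    args = Namespace(
        backbone="resnet50",
        lr_backbone=2e-5,
        masks=False,
        num_feature_levels=4,
        dilation=False,
        position_embedding="sine",
        position_embedding_scale=2 * math.pi,
        hidden_dim=256,
    )
    return build_backbone(args)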
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn_func import MSDeformAttnFunction
#!/usr/bin/env bash
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
python setup.py build install
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from .ms_deform_attn import MSDeformAttn
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}