# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Helper script to convert models trained with the main version of DETR to be used with the Detectron2 version. """ import json import argparse import numpy as np import torch def parse_args(): parser = argparse.ArgumentParser("D2 model converter") parser.add_argument("--source_model", default="", type=str, help="Path or url to the DETR model to convert") parser.add_argument("--output_model", default="", type=str, help="Path where to save the converted model") return parser.parse_args() def main(): args = parse_args() # D2 expects contiguous classes, so we need to remap the 92 classes from DETR # fmt: off coco_idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91] # fmt: on coco_idx = np.array(coco_idx) if args.source_model.startswith("https"): checkpoint = torch.hub.load_state_dict_from_url(args.source_model, map_location="cpu", check_hash=True) else: checkpoint = torch.load(args.source_model, map_location="cpu") model_to_convert = checkpoint["model"] model_converted = {} for k in model_to_convert.keys(): old_k = k if "backbone" in k: k = k.replace("backbone.0.body.", "") if "layer" not in k: k = "stem." + k for t in [1, 2, 3, 4]: k = k.replace(f"layer{t}", f"res{t + 1}") for t in [1, 2, 3]: k = k.replace(f"bn{t}", f"conv{t}.norm") k = k.replace("downsample.0", "shortcut") k = k.replace("downsample.1", "shortcut.norm") k = "backbone.0.backbone." + k k = "detr." + k print(old_k, "->", k) if "class_embed" in old_k: v = model_to_convert[old_k].detach() if v.shape[0] == 92: shape_old = v.shape model_converted[k] = v[coco_idx] print("Head conversion: changing shape from {} to {}".format(shape_old, model_converted[k].shape)) continue model_converted[k] = model_to_convert[old_k].detach() model_to_save = {"model": model_converted} torch.save(model_to_save, args.output_model) if __name__ == "__main__": main()