From cbc255859db08823102ecfca7b1d0c842854832c Mon Sep 17 00:00:00 2001 From: limm Date: Tue, 24 Jun 2025 17:46:51 +0800 Subject: [PATCH] add mmpretrain/ part --- mmpretrain/__init__.py | 28 + mmpretrain/apis/__init__.py | 22 + mmpretrain/apis/base.py | 390 ++++ mmpretrain/apis/feature_extractor.py | 130 ++ mmpretrain/apis/image_caption.py | 166 ++ mmpretrain/apis/image_classification.py | 223 ++ mmpretrain/apis/image_retrieval.py | 288 +++ mmpretrain/apis/model.py | 408 ++++ mmpretrain/apis/multimodal_retrieval.py | 603 ++++++ mmpretrain/apis/nlvr.py | 150 ++ mmpretrain/apis/utils.py | 270 +++ mmpretrain/apis/visual_grounding.py | 182 ++ mmpretrain/apis/visual_question_answering.py | 183 ++ .../configs/_base_/datasets/cifar10_bs16.py | 52 + .../configs/_base_/datasets/cub_bs8_384.py | 59 + .../_base_/datasets/imagenet21k_bs128.py | 35 + .../_base_/datasets/imagenet_bs128_mbv3.py | 75 + .../_base_/datasets/imagenet_bs256_beitv2.py | 53 + .../configs/_base_/datasets/imagenet_bs32.py | 62 + .../datasets/imagenet_bs32_pil_resize.py | 60 + .../_base_/datasets/imagenet_bs32_simclr.py | 63 + .../_base_/datasets/imagenet_bs512_mae.py | 40 + .../datasets/imagenet_bs64_pil_resize.py | 60 + .../imagenet_bs64_pil_resize_autoaug.py | 78 + .../_base_/datasets/imagenet_bs64_swin_224.py | 89 + .../_base_/datasets/imagenet_bs64_swin_256.py | 89 + .../_base_/datasets/imagenet_bs64_swin_384.py | 64 + mmpretrain/configs/_base_/default_runtime.py | 61 + .../configs/_base_/models/convnext_base.py | 25 + .../_base_/models/mae_hivit_base_p16.py | 28 + .../configs/_base_/models/mae_vit_base_p16.py | 28 + .../configs/_base_/models/mobilenet_v2_1x.py | 17 + .../_base_/models/mobilenet_v3_small.py | 25 + mmpretrain/configs/_base_/models/resnet18.py | 22 + .../_base_/models/swin_transformer_base.py | 20 + .../_base_/models/swin_transformer_v2_base.py | 19 + .../configs/_base_/models/vit_base_p16.py | 31 + .../configs/_base_/schedules/cifar10_bs128.py | 20 + .../configs/_base_/schedules/cub_bs64.py | 39 + .../schedules/imagenet_bs1024_adamw_swin.py | 46 + .../_base_/schedules/imagenet_bs256.py | 21 + .../schedules/imagenet_bs256_epochstep.py | 20 + .../_base_/schedules/imagenet_bs4096_adamw.py | 44 + .../schedules/imagenet_lars_coslr_200e.py | 27 + ...eit_base_p16_8xb256_amp_coslr_300e_in1k.py | 146 ++ .../beit-base-p16_8xb128-coslr-100e_in1k.py | 139 ++ .../benchmarks/beit-base-p16_8xb64_in1k.py | 50 + ...it-base-p16_8xb256-amp-coslr-1600e_in1k.py | 130 ++ ...eit-base-p16_8xb256-amp-coslr-300e_in1k.py | 130 ++ .../beit-base-p16_8xb128-coslr-100e_in1k.py | 132 ++ .../benchmarks/beit-base-p16_8xb64_in1k.py | 42 + .../convnext/convnext-base_32xb128_in1k.py | 28 + .../convnext/convnext-base_32xb128_in21k.py | 27 + .../convnext-large_64xb64_in1k-384px.py | 27 + .../convnext/convnext-large_64xb64_in1k.py | 27 + .../convnext/convnext-large_64xb64_in21k.py | 26 + .../convnext-small_32xb128_in1k-384px.py | 27 + .../convnext/convnext-small_32xb128_in1k.py | 27 + .../convnext-tiny_32xb128_in1k-384px.py | 27 + .../convnext/convnext-tiny_32xb128_in1k.py | 27 + .../convnext-xlarge_64xb64_in1k-384px.py | 27 + .../convnext/convnext-xlarge_64xb64_in1k.py | 27 + .../convnext/convnext-xlarge_64xb64_in21k.py | 28 + .../convnext_base_32xb128_in1k_384px.py | 28 + ...le_vit_base_p16_16xb256_coslr_400e_in1k.py | 92 + ...it_base_p16_8xb512_amp_coslr_1600e_in1k.py | 65 + ...vit_base_p16_8xb512_amp_coslr_400e_in1k.py | 65 + ...vit_base_p16_8xb512_amp_coslr_800e_in1k.py | 65 + ...t_large_p16_8xb512_amp_coslr_1600e_in1k.py | 70 + 
...it_large_p16_8xb512_amp_coslr_400e_in1k.py | 70 + ...it_large_p16_8xb512_amp_coslr_800e_in1k.py | 70 + ...it_base_p16_8xb512_amp_coslr_1600e_in1k.py | 65 + ...vit_base_p16_8xb512_amp_coslr_300e_in1k.py | 65 + ...vit_base_p16_8xb512_amp_coslr_400e_in1k.py | 65 + ...vit_base_p16_8xb512_amp_coslr_800e_in1k.py | 65 + ...it_huge_p14_8xb512_amp_coslr_1600e_in1k.py | 75 + ...t_large_p16_8xb512_amp_coslr_1600e_in1k.py | 70 + ...it_large_p16_8xb512_amp_coslr_300e_in1k.py | 70 + ...it_large_p16_8xb512_amp_coslr_400e_in1k.py | 70 + ...it_large_p16_8xb512_amp_coslr_800e_in1k.py | 70 + .../mobilenet_v2/mobilenet_v2_8xb32_in1k.py | 9 + .../mobilenet_v3_large_8xb128_in1k.py | 40 + .../mobilenet_v3_small_050_8xb128_in1k.py | 85 + .../mobilenet_v3_small_075_8xb128_in1k.py | 83 + .../mobilenet_v3_small_8xb128_in1k.py | 34 + .../mobilenet_v3_small_8xb16_cifar10.py | 34 + .../configs/resnet/resnet18_8xb32_in1k.py | 9 + ...simclr_resnet50_16xb256_coslr_200e_in1k.py | 58 + .../swin_transformer/swin_base_16xb64_in1k.py | 35 + .../swin_base_16xb64_in1k_384px.py | 12 + .../swin_large_16xb64_in1k.py | 18 + .../swin_large_16xb64_in1k_384px.py | 18 + .../swin_large_8xb8_cub_384px.py | 49 + .../swin_small_16xb64_in1k.py | 37 + .../swin_transformer/swin_tiny_16xb64_in1k.py | 37 + .../swinv2_base_w12_8xb128_in21k_192px.py | 32 + .../swinv2_base_w16_16xb64_in1k_256px.py | 24 + ...v2_base_w16_in21k_pre_16xb64_in1k_256px.py | 26 + ...v2_base_w24_in21k_pre_16xb64_in1k_384px.py | 14 + .../swinv2_base_w8_16xb64_in1k_256px.py | 23 + .../swinv2_large_w12_8xb128_in21k_192px.py | 32 + ...2_large_w16_in21k_pre_16xb64_in1k_256px.py | 24 + ...2_large_w24_in21k_pre_16xb64_in1k_384px.py | 24 + .../swinv2_small_w16_16xb64_in1k_256px.py | 28 + .../swinv2_small_w8_16xb64_in1k_256px.py | 24 + .../swinv2_tiny_w16_16xb64_in1k_256px.py | 28 + .../swinv2_tiny_w8_16xb64_in1k_256px.py | 24 + .../vit_base_p16_32xb128_mae_in1k.py | 52 + .../vit_base_p16_64xb64_in1k.py | 20 + .../vit_base_p16_64xb64_in1k_384px.py | 44 + .../vit_base_p32_64xb64_in1k.py | 26 + .../vit_base_p32_64xb64_in1k_384px.py | 48 + .../vit_large_p16_64xb64_in1k.py | 27 + .../vit_large_p16_64xb64_in1k_384px.py | 49 + .../vit_large_p32_64xb64_in1k.py | 27 + .../vit_large_p32_64xb64_in1k_384px.py | 49 + mmpretrain/datasets/__init__.py | 62 + mmpretrain/datasets/base_dataset.py | 219 ++ mmpretrain/datasets/builder.py | 25 + mmpretrain/datasets/caltech101.py | 113 ++ mmpretrain/datasets/categories.py | 1661 +++++++++++++++ mmpretrain/datasets/cifar.py | 210 ++ mmpretrain/datasets/coco_caption.py | 42 + mmpretrain/datasets/coco_retrieval.py | 148 ++ mmpretrain/datasets/coco_vqa.py | 114 ++ mmpretrain/datasets/cub.py | 142 ++ mmpretrain/datasets/custom.py | 287 +++ mmpretrain/datasets/dataset_wrappers.py | 176 ++ mmpretrain/datasets/dtd.py | 116 ++ mmpretrain/datasets/fgvcaircraft.py | 98 + mmpretrain/datasets/flamingo.py | 295 +++ mmpretrain/datasets/flickr30k_caption.py | 77 + mmpretrain/datasets/flickr30k_retrieval.py | 110 + mmpretrain/datasets/flowers102.py | 104 + mmpretrain/datasets/food101.py | 102 + mmpretrain/datasets/gqa_dataset.py | 70 + mmpretrain/datasets/iconqa.py | 63 + mmpretrain/datasets/imagenet.py | 235 +++ mmpretrain/datasets/infographic_vqa.py | 61 + mmpretrain/datasets/inshop.py | 157 ++ mmpretrain/datasets/minigpt4_dataset.py | 79 + mmpretrain/datasets/mnist.py | 234 +++ mmpretrain/datasets/multi_label.py | 85 + mmpretrain/datasets/multi_task.py | 337 ++++ mmpretrain/datasets/nlvr2.py | 36 + mmpretrain/datasets/nocaps.py | 46 + 
mmpretrain/datasets/ocr_vqa.py | 91 + mmpretrain/datasets/oxfordiiitpet.py | 97 + mmpretrain/datasets/places205.py | 40 + mmpretrain/datasets/refcoco.py | 112 + mmpretrain/datasets/samplers/__init__.py | 5 + mmpretrain/datasets/samplers/repeat_aug.py | 101 + mmpretrain/datasets/samplers/sequential.py | 56 + mmpretrain/datasets/scienceqa.py | 109 + mmpretrain/datasets/stanfordcars.py | 148 ++ mmpretrain/datasets/sun397.py | 125 ++ mmpretrain/datasets/textvqa.py | 105 + mmpretrain/datasets/transforms/__init__.py | 41 + .../datasets/transforms/auto_augment.py | 1244 ++++++++++++ mmpretrain/datasets/transforms/formatting.py | 353 ++++ mmpretrain/datasets/transforms/processing.py | 1795 +++++++++++++++++ mmpretrain/datasets/transforms/utils.py | 53 + mmpretrain/datasets/transforms/wrappers.py | 144 ++ mmpretrain/datasets/utils.py | 243 +++ mmpretrain/datasets/vg_vqa.py | 77 + mmpretrain/datasets/visual_genome.py | 95 + mmpretrain/datasets/vizwiz.py | 112 + mmpretrain/datasets/voc.py | 195 ++ mmpretrain/datasets/vsr.py | 55 + mmpretrain/engine/__init__.py | 5 + mmpretrain/engine/hooks/__init__.py | 19 + .../engine/hooks/class_num_check_hook.py | 63 + mmpretrain/engine/hooks/densecl_hook.py | 42 + mmpretrain/engine/hooks/ema_hook.py | 216 ++ mmpretrain/engine/hooks/margin_head_hooks.py | 61 + mmpretrain/engine/hooks/precise_bn_hook.py | 223 ++ mmpretrain/engine/hooks/retriever_hooks.py | 32 + mmpretrain/engine/hooks/simsiam_hook.py | 48 + mmpretrain/engine/hooks/swav_hook.py | 119 ++ mmpretrain/engine/hooks/switch_recipe_hook.py | 169 ++ mmpretrain/engine/hooks/visualization_hook.py | 126 ++ mmpretrain/engine/hooks/warmup_param_hook.py | 66 + mmpretrain/engine/optimizers/__init__.py | 8 + mmpretrain/engine/optimizers/adan_t.py | 312 +++ mmpretrain/engine/optimizers/lamb.py | 228 +++ mmpretrain/engine/optimizers/lars.py | 130 ++ .../layer_decay_optim_wrapper_constructor.py | 166 ++ mmpretrain/engine/runners/__init__.py | 4 + mmpretrain/engine/runners/retrieval_loop.py | 168 ++ mmpretrain/engine/schedulers/__init__.py | 4 + .../schedulers/weight_decay_scheduler.py | 64 + mmpretrain/evaluation/__init__.py | 3 + mmpretrain/evaluation/functional/__init__.py | 1 + mmpretrain/evaluation/metrics/ANLS.py | 103 + mmpretrain/evaluation/metrics/__init__.py | 22 + mmpretrain/evaluation/metrics/caption.py | 136 ++ mmpretrain/evaluation/metrics/gqa.py | 78 + mmpretrain/evaluation/metrics/multi_label.py | 599 ++++++ mmpretrain/evaluation/metrics/multi_task.py | 120 ++ mmpretrain/evaluation/metrics/nocaps.py | 59 + mmpretrain/evaluation/metrics/retrieval.py | 445 ++++ mmpretrain/evaluation/metrics/scienceqa.py | 170 ++ .../evaluation/metrics/shape_bias_label.py | 172 ++ mmpretrain/evaluation/metrics/single_label.py | 776 +++++++ .../metrics/visual_grounding_eval.py | 85 + .../evaluation/metrics/voc_multi_label.py | 98 + mmpretrain/evaluation/metrics/vqa.py | 315 +++ mmpretrain/models/__init__.py | 20 + mmpretrain/models/backbones/__init__.py | 129 ++ mmpretrain/models/backbones/alexnet.py | 56 + mmpretrain/models/backbones/base_backbone.py | 33 + mmpretrain/models/backbones/beit.py | 697 +++++++ mmpretrain/models/backbones/conformer.py | 621 ++++++ mmpretrain/models/backbones/convmixer.py | 176 ++ mmpretrain/models/backbones/convnext.py | 412 ++++ mmpretrain/models/backbones/cspnet.py | 679 +++++++ mmpretrain/models/backbones/davit.py | 834 ++++++++ mmpretrain/models/backbones/deit.py | 116 ++ mmpretrain/models/backbones/deit3.py | 454 +++++ mmpretrain/models/backbones/densenet.py | 332 +++ 
mmpretrain/models/backbones/edgenext.py | 398 ++++ .../models/backbones/efficientformer.py | 606 ++++++ mmpretrain/models/backbones/efficientnet.py | 410 ++++ .../models/backbones/efficientnet_v2.py | 343 ++++ mmpretrain/models/backbones/hivit.py | 656 ++++++ mmpretrain/models/backbones/hornet.py | 500 +++++ mmpretrain/models/backbones/hrnet.py | 563 ++++++ mmpretrain/models/backbones/inception_v3.py | 501 +++++ mmpretrain/models/backbones/lenet.py | 42 + mmpretrain/models/backbones/levit.py | 522 +++++ mmpretrain/models/backbones/mixmim.py | 533 +++++ mmpretrain/models/backbones/mlp_mixer.py | 263 +++ mmpretrain/models/backbones/mobilenet_v2.py | 264 +++ mmpretrain/models/backbones/mobilenet_v3.py | 217 ++ mmpretrain/models/backbones/mobileone.py | 515 +++++ mmpretrain/models/backbones/mobilevit.py | 431 ++++ mmpretrain/models/backbones/mvit.py | 700 +++++++ mmpretrain/models/backbones/poolformer.py | 416 ++++ mmpretrain/models/backbones/regnet.py | 312 +++ mmpretrain/models/backbones/replknet.py | 668 ++++++ mmpretrain/models/backbones/repmlp.py | 578 ++++++ mmpretrain/models/backbones/repvgg.py | 622 ++++++ mmpretrain/models/backbones/res2net.py | 317 +++ mmpretrain/models/backbones/resnest.py | 339 ++++ mmpretrain/models/backbones/resnet.py | 768 +++++++ mmpretrain/models/backbones/resnet_cifar.py | 81 + mmpretrain/models/backbones/resnext.py | 148 ++ mmpretrain/models/backbones/revvit.py | 671 ++++++ mmpretrain/models/backbones/riformer.py | 390 ++++ mmpretrain/models/backbones/seresnet.py | 125 ++ mmpretrain/models/backbones/seresnext.py | 155 ++ mmpretrain/models/backbones/shufflenet_v1.py | 321 +++ mmpretrain/models/backbones/shufflenet_v2.py | 305 +++ .../models/backbones/sparse_convnext.py | 298 +++ mmpretrain/models/backbones/sparse_resnet.py | 179 ++ .../models/backbones/swin_transformer.py | 585 ++++++ .../models/backbones/swin_transformer_v2.py | 567 ++++++ mmpretrain/models/backbones/t2t_vit.py | 447 ++++ mmpretrain/models/backbones/timm_backbone.py | 111 + mmpretrain/models/backbones/tinyvit.py | 769 +++++++ mmpretrain/models/backbones/tnt.py | 368 ++++ mmpretrain/models/backbones/twins.py | 721 +++++++ mmpretrain/models/backbones/van.py | 434 ++++ mmpretrain/models/backbones/vgg.py | 183 ++ mmpretrain/models/backbones/vig.py | 852 ++++++++ .../models/backbones/vision_transformer.py | 537 +++++ mmpretrain/models/backbones/vit_eva02.py | 350 ++++ mmpretrain/models/backbones/vit_sam.py | 697 +++++++ mmpretrain/models/backbones/xcit.py | 770 +++++++ mmpretrain/models/builder.py | 39 + mmpretrain/models/classifiers/__init__.py | 10 + mmpretrain/models/classifiers/base.py | 108 + mmpretrain/models/classifiers/hugging_face.py | 222 ++ mmpretrain/models/classifiers/image.py | 265 +++ mmpretrain/models/classifiers/timm.py | 209 ++ mmpretrain/models/heads/__init__.py | 69 + mmpretrain/models/heads/beitv1_head.py | 55 + mmpretrain/models/heads/beitv2_head.py | 57 + mmpretrain/models/heads/cae_head.py | 69 + mmpretrain/models/heads/cls_head.py | 156 ++ mmpretrain/models/heads/conformer_head.py | 122 ++ mmpretrain/models/heads/contrastive_head.py | 50 + mmpretrain/models/heads/deit_head.py | 72 + .../models/heads/efficientformer_head.py | 89 + mmpretrain/models/heads/grounding_head.py | 217 ++ mmpretrain/models/heads/itc_head.py | 157 ++ mmpretrain/models/heads/itm_head.py | 117 ++ mmpretrain/models/heads/itpn_clip_head.py | 56 + mmpretrain/models/heads/latent_heads.py | 94 + mmpretrain/models/heads/levit_head.py | 81 + mmpretrain/models/heads/linear_head.py | 63 + 
mmpretrain/models/heads/mae_head.py | 106 + mmpretrain/models/heads/margin_head.py | 300 +++ mmpretrain/models/heads/mim_head.py | 37 + mmpretrain/models/heads/mixmim_head.py | 49 + mmpretrain/models/heads/mocov3_head.py | 66 + .../models/heads/multi_label_cls_head.py | 155 ++ .../models/heads/multi_label_csra_head.py | 112 + .../models/heads/multi_label_linear_head.py | 66 + mmpretrain/models/heads/multi_task_head.py | 153 ++ mmpretrain/models/heads/seq_gen_head.py | 188 ++ mmpretrain/models/heads/simmim_head.py | 40 + mmpretrain/models/heads/spark_head.py | 92 + mmpretrain/models/heads/stacked_head.py | 135 ++ mmpretrain/models/heads/swav_head.py | 31 + mmpretrain/models/heads/vig_head.py | 65 + .../models/heads/vision_transformer_head.py | 97 + mmpretrain/models/heads/vqa_head.py | 246 +++ mmpretrain/models/losses/__init__.py | 35 + mmpretrain/models/losses/asymmetric_loss.py | 149 ++ mmpretrain/models/losses/cae_loss.py | 48 + .../models/losses/cosine_similarity_loss.py | 55 + .../models/losses/cross_correlation_loss.py | 44 + .../models/losses/cross_entropy_loss.py | 209 ++ mmpretrain/models/losses/focal_loss.py | 116 ++ mmpretrain/models/losses/label_smooth_loss.py | 177 ++ .../models/losses/reconstruction_loss.py | 67 + mmpretrain/models/losses/seesaw_loss.py | 173 ++ mmpretrain/models/losses/swav_loss.py | 190 ++ mmpretrain/models/losses/utils.py | 119 ++ mmpretrain/models/multimodal/__init__.py | 24 + mmpretrain/models/multimodal/blip/__init__.py | 12 + .../models/multimodal/blip/blip_caption.py | 184 ++ .../models/multimodal/blip/blip_grounding.py | 248 +++ .../models/multimodal/blip/blip_nlvr.py | 205 ++ .../models/multimodal/blip/blip_retrieval.py | 716 +++++++ mmpretrain/models/multimodal/blip/blip_vqa.py | 265 +++ .../models/multimodal/blip/language_model.py | 1320 ++++++++++++ mmpretrain/models/multimodal/blip2/Qformer.py | 773 +++++++ .../models/multimodal/blip2/__init__.py | 10 + .../models/multimodal/blip2/blip2_caption.py | 315 +++ .../models/multimodal/blip2/blip2_opt_vqa.py | 92 + .../multimodal/blip2/blip2_retriever.py | 505 +++++ .../models/multimodal/blip2/modeling_opt.py | 1083 ++++++++++ .../multimodal/chinese_clip/__init__.py | 5 + .../models/multimodal/chinese_clip/bert.py | 263 +++ .../multimodal/chinese_clip/chinese_clip.py | 446 ++++ .../models/multimodal/chinese_clip/utils.py | 186 ++ mmpretrain/models/multimodal/clip/__init__.py | 5 + mmpretrain/models/multimodal/clip/clip.py | 364 ++++ .../multimodal/clip/clip_transformer.py | 99 + mmpretrain/models/multimodal/clip/utils.py | 115 ++ .../models/multimodal/flamingo/__init__.py | 5 + .../models/multimodal/flamingo/adapter.py | 96 + .../models/multimodal/flamingo/flamingo.py | 323 +++ .../models/multimodal/flamingo/modules.py | 398 ++++ .../models/multimodal/flamingo/utils.py | 64 + .../models/multimodal/llava/__init__.py | 5 + mmpretrain/models/multimodal/llava/llava.py | 267 +++ mmpretrain/models/multimodal/llava/modules.py | 234 +++ .../models/multimodal/minigpt4/__init__.py | 4 + .../models/multimodal/minigpt4/minigpt4.py | 410 ++++ mmpretrain/models/multimodal/ofa/__init__.py | 5 + mmpretrain/models/multimodal/ofa/ofa.py | 320 +++ .../models/multimodal/ofa/ofa_modules.py | 1613 +++++++++++++++ .../models/multimodal/otter/__init__.py | 4 + mmpretrain/models/multimodal/otter/otter.py | 143 ++ mmpretrain/models/multimodal/ram/__init__.py | 4 + mmpretrain/models/multimodal/ram/bert.py | 1197 +++++++++++ .../models/multimodal/ram/config/__init__.py | 1 + .../ram/config/ram_swin_large_14m.py | 93 + 
.../multimodal/ram/data/ram_tag_list.pickle | Bin 0 -> 51099 bytes .../ram/data/ram_tag_list_chinese.pickle | Bin 0 -> 50796 bytes .../ram/data/ram_tag_list_threshold.pickle | Bin 0 -> 41289 bytes .../models/multimodal/ram/gradio_demo.py | 109 + .../models/multimodal/ram/openset_utils.py | 212 ++ mmpretrain/models/multimodal/ram/ram.py | 332 +++ .../models/multimodal/ram/run/__init__.py | 1 + .../models/multimodal/ram/run/inference.py | 29 + mmpretrain/models/multimodal/ram/utils.py | 87 + mmpretrain/models/necks/__init__.py | 37 + mmpretrain/models/necks/beitv2_neck.py | 153 ++ mmpretrain/models/necks/cae_neck.py | 273 +++ mmpretrain/models/necks/densecl_neck.py | 71 + mmpretrain/models/necks/gap.py | 45 + mmpretrain/models/necks/gem.py | 53 + mmpretrain/models/necks/hr_fuse.py | 83 + mmpretrain/models/necks/itpn_neck.py | 388 ++++ mmpretrain/models/necks/linear_neck.py | 88 + mmpretrain/models/necks/mae_neck.py | 188 ++ mmpretrain/models/necks/milan_neck.py | 222 ++ mmpretrain/models/necks/mixmim_neck.py | 111 + mmpretrain/models/necks/mocov2_neck.py | 52 + mmpretrain/models/necks/nonlinear_neck.py | 115 ++ mmpretrain/models/necks/simmim_neck.py | 33 + mmpretrain/models/necks/spark_neck.py | 169 ++ mmpretrain/models/necks/swav_neck.py | 93 + mmpretrain/models/peft/__init__.py | 6 + mmpretrain/models/peft/lora.py | 205 ++ mmpretrain/models/retrievers/__init__.py | 5 + mmpretrain/models/retrievers/base.py | 151 ++ mmpretrain/models/retrievers/image2image.py | 314 +++ mmpretrain/models/selfsup/__init__.py | 59 + mmpretrain/models/selfsup/barlowtwins.py | 42 + mmpretrain/models/selfsup/base.py | 179 ++ mmpretrain/models/selfsup/beit.py | 357 ++++ mmpretrain/models/selfsup/byol.py | 89 + mmpretrain/models/selfsup/cae.py | 472 +++++ mmpretrain/models/selfsup/densecl.py | 203 ++ mmpretrain/models/selfsup/eva.py | 43 + mmpretrain/models/selfsup/itpn.py | 359 ++++ mmpretrain/models/selfsup/mae.py | 416 ++++ mmpretrain/models/selfsup/maskfeat.py | 336 +++ mmpretrain/models/selfsup/mff.py | 194 ++ mmpretrain/models/selfsup/milan.py | 202 ++ mmpretrain/models/selfsup/mixmim.py | 263 +++ mmpretrain/models/selfsup/moco.py | 137 ++ mmpretrain/models/selfsup/mocov3.py | 215 ++ mmpretrain/models/selfsup/simclr.py | 98 + mmpretrain/models/selfsup/simmim.py | 194 ++ mmpretrain/models/selfsup/simsiam.py | 43 + mmpretrain/models/selfsup/spark.py | 163 ++ mmpretrain/models/selfsup/swav.py | 49 + mmpretrain/models/tta/__init__.py | 4 + mmpretrain/models/tta/score_tta.py | 36 + mmpretrain/models/utils/__init__.py | 102 + mmpretrain/models/utils/attention.py | 1129 +++++++++++ .../models/utils/batch_augments/__init__.py | 7 + .../models/utils/batch_augments/cutmix.py | 157 ++ .../models/utils/batch_augments/mixup.py | 65 + .../models/utils/batch_augments/resizemix.py | 95 + .../models/utils/batch_augments/wrapper.py | 74 + mmpretrain/models/utils/batch_shuffle.py | 66 + mmpretrain/models/utils/box_utils.py | 56 + mmpretrain/models/utils/channel_shuffle.py | 29 + .../models/utils/clip_generator_helper.py | 394 ++++ mmpretrain/models/utils/data_preprocessor.py | 620 ++++++ mmpretrain/models/utils/ema.py | 87 + mmpretrain/models/utils/embed.py | 423 ++++ mmpretrain/models/utils/helpers.py | 53 + mmpretrain/models/utils/huggingface.py | 100 + mmpretrain/models/utils/inverted_residual.py | 125 ++ mmpretrain/models/utils/layer_scale.py | 40 + mmpretrain/models/utils/make_divisible.py | 25 + mmpretrain/models/utils/norm.py | 133 ++ mmpretrain/models/utils/position_encoding.py | 247 +++ 
.../models/utils/res_layer_extra_norm.py | 31 + mmpretrain/models/utils/se_layer.py | 80 + mmpretrain/models/utils/sparse_modules.py | 149 ++ mmpretrain/models/utils/swiglu_ffn.py | 98 + mmpretrain/models/utils/tokenizer.py | 188 ++ mmpretrain/models/utils/vector_quantizer.py | 232 +++ mmpretrain/registry.py | 195 ++ mmpretrain/structures/__init__.py | 10 + mmpretrain/structures/data_sample.py | 167 ++ .../structures/multi_task_data_sample.py | 10 + mmpretrain/structures/utils.py | 153 ++ mmpretrain/utils/__init__.py | 12 + mmpretrain/utils/analyze.py | 43 + mmpretrain/utils/collect_env.py | 16 + mmpretrain/utils/dependency.py | 82 + mmpretrain/utils/misc.py | 18 + mmpretrain/utils/progress.py | 40 + mmpretrain/utils/setup_env.py | 41 + mmpretrain/version.py | 28 + mmpretrain/visualization/__init__.py | 5 + mmpretrain/visualization/utils.py | 60 + mmpretrain/visualization/visualizer.py | 777 +++++++ 458 files changed, 82672 insertions(+) create mode 100644 mmpretrain/__init__.py create mode 100644 mmpretrain/apis/__init__.py create mode 100644 mmpretrain/apis/base.py create mode 100644 mmpretrain/apis/feature_extractor.py create mode 100644 mmpretrain/apis/image_caption.py create mode 100644 mmpretrain/apis/image_classification.py create mode 100644 mmpretrain/apis/image_retrieval.py create mode 100644 mmpretrain/apis/model.py create mode 100644 mmpretrain/apis/multimodal_retrieval.py create mode 100644 mmpretrain/apis/nlvr.py create mode 100644 mmpretrain/apis/utils.py create mode 100644 mmpretrain/apis/visual_grounding.py create mode 100644 mmpretrain/apis/visual_question_answering.py create mode 100644 mmpretrain/configs/_base_/datasets/cifar10_bs16.py create mode 100644 mmpretrain/configs/_base_/datasets/cub_bs8_384.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet21k_bs128.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs256_beitv2.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs32.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_224.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py create mode 100644 mmpretrain/configs/_base_/default_runtime.py create mode 100644 mmpretrain/configs/_base_/models/convnext_base.py create mode 100644 mmpretrain/configs/_base_/models/mae_hivit_base_p16.py create mode 100644 mmpretrain/configs/_base_/models/mae_vit_base_p16.py create mode 100644 mmpretrain/configs/_base_/models/mobilenet_v2_1x.py create mode 100644 mmpretrain/configs/_base_/models/mobilenet_v3_small.py create mode 100644 mmpretrain/configs/_base_/models/resnet18.py create mode 100644 mmpretrain/configs/_base_/models/swin_transformer_base.py create mode 100644 mmpretrain/configs/_base_/models/swin_transformer_v2_base.py create mode 100644 mmpretrain/configs/_base_/models/vit_base_p16.py create mode 100644 mmpretrain/configs/_base_/schedules/cifar10_bs128.py create mode 100644 mmpretrain/configs/_base_/schedules/cub_bs64.py 
create mode 100644 mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py create mode 100644 mmpretrain/configs/_base_/schedules/imagenet_bs256.py create mode 100644 mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py create mode 100644 mmpretrain/configs/_base_/schedules/imagenet_bs4096_adamw.py create mode 100644 mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py create mode 100644 mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py create mode 100644 mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py create mode 100644 mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py create mode 100644 mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py create mode 100644 mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py create mode 100644 mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py create mode 100644 mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py create mode 100644 mmpretrain/configs/convnext/convnext-base_32xb128_in1k.py create mode 100644 mmpretrain/configs/convnext/convnext-base_32xb128_in21k.py create mode 100644 mmpretrain/configs/convnext/convnext-large_64xb64_in1k-384px.py create mode 100644 mmpretrain/configs/convnext/convnext-large_64xb64_in1k.py create mode 100644 mmpretrain/configs/convnext/convnext-large_64xb64_in21k.py create mode 100644 mmpretrain/configs/convnext/convnext-small_32xb128_in1k-384px.py create mode 100644 mmpretrain/configs/convnext/convnext-small_32xb128_in1k.py create mode 100644 mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k-384px.py create mode 100644 mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k.py create mode 100644 mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k-384px.py create mode 100644 mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k.py create mode 100644 mmpretrain/configs/convnext/convnext-xlarge_64xb64_in21k.py create mode 100644 mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py create mode 100644 mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_1600e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_400e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_800e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_1600e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_400e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_800e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_1600e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_400e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_800e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_huge_p14_8xb512_amp_coslr_1600e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_1600e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_300e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_400e_in1k.py create mode 100644 mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_800e_in1k.py create mode 
100644 mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py create mode 100644 mmpretrain/configs/resnet/resnet18_8xb32_in1k.py create mode 100644 mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py create mode 100644 mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py create mode 100644 mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py create mode 100644 mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py create mode 100644 mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py create mode 100644 mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py create mode 100644 mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py create mode 100644 mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py create mode 100644 mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py create mode 100644 mmpretrain/configs/vision_transformer/vit_base_p16_32xb128_mae_in1k.py create mode 100644 mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k.py create mode 100644 mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k_384px.py create mode 100644 mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k.py create mode 100644 mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k_384px.py create mode 100644 mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k.py create mode 100644 mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k_384px.py create mode 100644 mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k.py create mode 100644 mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k_384px.py create mode 100644 mmpretrain/datasets/__init__.py create mode 100644 mmpretrain/datasets/base_dataset.py create mode 100644 mmpretrain/datasets/builder.py create mode 100644 mmpretrain/datasets/caltech101.py create mode 100644 mmpretrain/datasets/categories.py create mode 100644 mmpretrain/datasets/cifar.py create mode 100644 
mmpretrain/datasets/coco_caption.py create mode 100644 mmpretrain/datasets/coco_retrieval.py create mode 100644 mmpretrain/datasets/coco_vqa.py create mode 100644 mmpretrain/datasets/cub.py create mode 100644 mmpretrain/datasets/custom.py create mode 100644 mmpretrain/datasets/dataset_wrappers.py create mode 100644 mmpretrain/datasets/dtd.py create mode 100644 mmpretrain/datasets/fgvcaircraft.py create mode 100644 mmpretrain/datasets/flamingo.py create mode 100644 mmpretrain/datasets/flickr30k_caption.py create mode 100644 mmpretrain/datasets/flickr30k_retrieval.py create mode 100644 mmpretrain/datasets/flowers102.py create mode 100644 mmpretrain/datasets/food101.py create mode 100644 mmpretrain/datasets/gqa_dataset.py create mode 100644 mmpretrain/datasets/iconqa.py create mode 100644 mmpretrain/datasets/imagenet.py create mode 100644 mmpretrain/datasets/infographic_vqa.py create mode 100644 mmpretrain/datasets/inshop.py create mode 100644 mmpretrain/datasets/minigpt4_dataset.py create mode 100644 mmpretrain/datasets/mnist.py create mode 100644 mmpretrain/datasets/multi_label.py create mode 100644 mmpretrain/datasets/multi_task.py create mode 100644 mmpretrain/datasets/nlvr2.py create mode 100644 mmpretrain/datasets/nocaps.py create mode 100644 mmpretrain/datasets/ocr_vqa.py create mode 100644 mmpretrain/datasets/oxfordiiitpet.py create mode 100644 mmpretrain/datasets/places205.py create mode 100644 mmpretrain/datasets/refcoco.py create mode 100644 mmpretrain/datasets/samplers/__init__.py create mode 100644 mmpretrain/datasets/samplers/repeat_aug.py create mode 100644 mmpretrain/datasets/samplers/sequential.py create mode 100644 mmpretrain/datasets/scienceqa.py create mode 100644 mmpretrain/datasets/stanfordcars.py create mode 100644 mmpretrain/datasets/sun397.py create mode 100644 mmpretrain/datasets/textvqa.py create mode 100644 mmpretrain/datasets/transforms/__init__.py create mode 100644 mmpretrain/datasets/transforms/auto_augment.py create mode 100644 mmpretrain/datasets/transforms/formatting.py create mode 100644 mmpretrain/datasets/transforms/processing.py create mode 100644 mmpretrain/datasets/transforms/utils.py create mode 100644 mmpretrain/datasets/transforms/wrappers.py create mode 100644 mmpretrain/datasets/utils.py create mode 100644 mmpretrain/datasets/vg_vqa.py create mode 100644 mmpretrain/datasets/visual_genome.py create mode 100644 mmpretrain/datasets/vizwiz.py create mode 100644 mmpretrain/datasets/voc.py create mode 100644 mmpretrain/datasets/vsr.py create mode 100644 mmpretrain/engine/__init__.py create mode 100644 mmpretrain/engine/hooks/__init__.py create mode 100644 mmpretrain/engine/hooks/class_num_check_hook.py create mode 100644 mmpretrain/engine/hooks/densecl_hook.py create mode 100644 mmpretrain/engine/hooks/ema_hook.py create mode 100644 mmpretrain/engine/hooks/margin_head_hooks.py create mode 100644 mmpretrain/engine/hooks/precise_bn_hook.py create mode 100644 mmpretrain/engine/hooks/retriever_hooks.py create mode 100644 mmpretrain/engine/hooks/simsiam_hook.py create mode 100644 mmpretrain/engine/hooks/swav_hook.py create mode 100644 mmpretrain/engine/hooks/switch_recipe_hook.py create mode 100644 mmpretrain/engine/hooks/visualization_hook.py create mode 100644 mmpretrain/engine/hooks/warmup_param_hook.py create mode 100644 mmpretrain/engine/optimizers/__init__.py create mode 100644 mmpretrain/engine/optimizers/adan_t.py create mode 100644 mmpretrain/engine/optimizers/lamb.py create mode 100644 mmpretrain/engine/optimizers/lars.py create mode 100644 
mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py create mode 100644 mmpretrain/engine/runners/__init__.py create mode 100644 mmpretrain/engine/runners/retrieval_loop.py create mode 100644 mmpretrain/engine/schedulers/__init__.py create mode 100644 mmpretrain/engine/schedulers/weight_decay_scheduler.py create mode 100644 mmpretrain/evaluation/__init__.py create mode 100644 mmpretrain/evaluation/functional/__init__.py create mode 100644 mmpretrain/evaluation/metrics/ANLS.py create mode 100644 mmpretrain/evaluation/metrics/__init__.py create mode 100644 mmpretrain/evaluation/metrics/caption.py create mode 100644 mmpretrain/evaluation/metrics/gqa.py create mode 100644 mmpretrain/evaluation/metrics/multi_label.py create mode 100644 mmpretrain/evaluation/metrics/multi_task.py create mode 100644 mmpretrain/evaluation/metrics/nocaps.py create mode 100644 mmpretrain/evaluation/metrics/retrieval.py create mode 100644 mmpretrain/evaluation/metrics/scienceqa.py create mode 100644 mmpretrain/evaluation/metrics/shape_bias_label.py create mode 100644 mmpretrain/evaluation/metrics/single_label.py create mode 100644 mmpretrain/evaluation/metrics/visual_grounding_eval.py create mode 100644 mmpretrain/evaluation/metrics/voc_multi_label.py create mode 100644 mmpretrain/evaluation/metrics/vqa.py create mode 100644 mmpretrain/models/__init__.py create mode 100644 mmpretrain/models/backbones/__init__.py create mode 100644 mmpretrain/models/backbones/alexnet.py create mode 100644 mmpretrain/models/backbones/base_backbone.py create mode 100644 mmpretrain/models/backbones/beit.py create mode 100644 mmpretrain/models/backbones/conformer.py create mode 100644 mmpretrain/models/backbones/convmixer.py create mode 100644 mmpretrain/models/backbones/convnext.py create mode 100644 mmpretrain/models/backbones/cspnet.py create mode 100644 mmpretrain/models/backbones/davit.py create mode 100644 mmpretrain/models/backbones/deit.py create mode 100644 mmpretrain/models/backbones/deit3.py create mode 100644 mmpretrain/models/backbones/densenet.py create mode 100644 mmpretrain/models/backbones/edgenext.py create mode 100644 mmpretrain/models/backbones/efficientformer.py create mode 100644 mmpretrain/models/backbones/efficientnet.py create mode 100644 mmpretrain/models/backbones/efficientnet_v2.py create mode 100644 mmpretrain/models/backbones/hivit.py create mode 100644 mmpretrain/models/backbones/hornet.py create mode 100644 mmpretrain/models/backbones/hrnet.py create mode 100644 mmpretrain/models/backbones/inception_v3.py create mode 100644 mmpretrain/models/backbones/lenet.py create mode 100644 mmpretrain/models/backbones/levit.py create mode 100644 mmpretrain/models/backbones/mixmim.py create mode 100644 mmpretrain/models/backbones/mlp_mixer.py create mode 100644 mmpretrain/models/backbones/mobilenet_v2.py create mode 100644 mmpretrain/models/backbones/mobilenet_v3.py create mode 100644 mmpretrain/models/backbones/mobileone.py create mode 100644 mmpretrain/models/backbones/mobilevit.py create mode 100644 mmpretrain/models/backbones/mvit.py create mode 100644 mmpretrain/models/backbones/poolformer.py create mode 100644 mmpretrain/models/backbones/regnet.py create mode 100644 mmpretrain/models/backbones/replknet.py create mode 100644 mmpretrain/models/backbones/repmlp.py create mode 100644 mmpretrain/models/backbones/repvgg.py create mode 100644 mmpretrain/models/backbones/res2net.py create mode 100644 mmpretrain/models/backbones/resnest.py create mode 100644 mmpretrain/models/backbones/resnet.py create 
mode 100644 mmpretrain/models/backbones/resnet_cifar.py create mode 100644 mmpretrain/models/backbones/resnext.py create mode 100644 mmpretrain/models/backbones/revvit.py create mode 100644 mmpretrain/models/backbones/riformer.py create mode 100644 mmpretrain/models/backbones/seresnet.py create mode 100644 mmpretrain/models/backbones/seresnext.py create mode 100644 mmpretrain/models/backbones/shufflenet_v1.py create mode 100644 mmpretrain/models/backbones/shufflenet_v2.py create mode 100644 mmpretrain/models/backbones/sparse_convnext.py create mode 100644 mmpretrain/models/backbones/sparse_resnet.py create mode 100644 mmpretrain/models/backbones/swin_transformer.py create mode 100644 mmpretrain/models/backbones/swin_transformer_v2.py create mode 100644 mmpretrain/models/backbones/t2t_vit.py create mode 100644 mmpretrain/models/backbones/timm_backbone.py create mode 100644 mmpretrain/models/backbones/tinyvit.py create mode 100644 mmpretrain/models/backbones/tnt.py create mode 100644 mmpretrain/models/backbones/twins.py create mode 100644 mmpretrain/models/backbones/van.py create mode 100644 mmpretrain/models/backbones/vgg.py create mode 100644 mmpretrain/models/backbones/vig.py create mode 100644 mmpretrain/models/backbones/vision_transformer.py create mode 100644 mmpretrain/models/backbones/vit_eva02.py create mode 100644 mmpretrain/models/backbones/vit_sam.py create mode 100644 mmpretrain/models/backbones/xcit.py create mode 100644 mmpretrain/models/builder.py create mode 100644 mmpretrain/models/classifiers/__init__.py create mode 100644 mmpretrain/models/classifiers/base.py create mode 100644 mmpretrain/models/classifiers/hugging_face.py create mode 100644 mmpretrain/models/classifiers/image.py create mode 100644 mmpretrain/models/classifiers/timm.py create mode 100644 mmpretrain/models/heads/__init__.py create mode 100644 mmpretrain/models/heads/beitv1_head.py create mode 100644 mmpretrain/models/heads/beitv2_head.py create mode 100644 mmpretrain/models/heads/cae_head.py create mode 100644 mmpretrain/models/heads/cls_head.py create mode 100644 mmpretrain/models/heads/conformer_head.py create mode 100644 mmpretrain/models/heads/contrastive_head.py create mode 100644 mmpretrain/models/heads/deit_head.py create mode 100644 mmpretrain/models/heads/efficientformer_head.py create mode 100644 mmpretrain/models/heads/grounding_head.py create mode 100644 mmpretrain/models/heads/itc_head.py create mode 100644 mmpretrain/models/heads/itm_head.py create mode 100644 mmpretrain/models/heads/itpn_clip_head.py create mode 100644 mmpretrain/models/heads/latent_heads.py create mode 100644 mmpretrain/models/heads/levit_head.py create mode 100644 mmpretrain/models/heads/linear_head.py create mode 100644 mmpretrain/models/heads/mae_head.py create mode 100644 mmpretrain/models/heads/margin_head.py create mode 100644 mmpretrain/models/heads/mim_head.py create mode 100644 mmpretrain/models/heads/mixmim_head.py create mode 100644 mmpretrain/models/heads/mocov3_head.py create mode 100644 mmpretrain/models/heads/multi_label_cls_head.py create mode 100644 mmpretrain/models/heads/multi_label_csra_head.py create mode 100644 mmpretrain/models/heads/multi_label_linear_head.py create mode 100644 mmpretrain/models/heads/multi_task_head.py create mode 100644 mmpretrain/models/heads/seq_gen_head.py create mode 100644 mmpretrain/models/heads/simmim_head.py create mode 100644 mmpretrain/models/heads/spark_head.py create mode 100644 mmpretrain/models/heads/stacked_head.py create mode 100644 
mmpretrain/models/heads/swav_head.py create mode 100644 mmpretrain/models/heads/vig_head.py create mode 100644 mmpretrain/models/heads/vision_transformer_head.py create mode 100644 mmpretrain/models/heads/vqa_head.py create mode 100644 mmpretrain/models/losses/__init__.py create mode 100644 mmpretrain/models/losses/asymmetric_loss.py create mode 100644 mmpretrain/models/losses/cae_loss.py create mode 100644 mmpretrain/models/losses/cosine_similarity_loss.py create mode 100644 mmpretrain/models/losses/cross_correlation_loss.py create mode 100644 mmpretrain/models/losses/cross_entropy_loss.py create mode 100644 mmpretrain/models/losses/focal_loss.py create mode 100644 mmpretrain/models/losses/label_smooth_loss.py create mode 100644 mmpretrain/models/losses/reconstruction_loss.py create mode 100644 mmpretrain/models/losses/seesaw_loss.py create mode 100644 mmpretrain/models/losses/swav_loss.py create mode 100644 mmpretrain/models/losses/utils.py create mode 100644 mmpretrain/models/multimodal/__init__.py create mode 100644 mmpretrain/models/multimodal/blip/__init__.py create mode 100644 mmpretrain/models/multimodal/blip/blip_caption.py create mode 100644 mmpretrain/models/multimodal/blip/blip_grounding.py create mode 100644 mmpretrain/models/multimodal/blip/blip_nlvr.py create mode 100644 mmpretrain/models/multimodal/blip/blip_retrieval.py create mode 100644 mmpretrain/models/multimodal/blip/blip_vqa.py create mode 100644 mmpretrain/models/multimodal/blip/language_model.py create mode 100644 mmpretrain/models/multimodal/blip2/Qformer.py create mode 100644 mmpretrain/models/multimodal/blip2/__init__.py create mode 100644 mmpretrain/models/multimodal/blip2/blip2_caption.py create mode 100644 mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py create mode 100644 mmpretrain/models/multimodal/blip2/blip2_retriever.py create mode 100644 mmpretrain/models/multimodal/blip2/modeling_opt.py create mode 100644 mmpretrain/models/multimodal/chinese_clip/__init__.py create mode 100644 mmpretrain/models/multimodal/chinese_clip/bert.py create mode 100644 mmpretrain/models/multimodal/chinese_clip/chinese_clip.py create mode 100644 mmpretrain/models/multimodal/chinese_clip/utils.py create mode 100644 mmpretrain/models/multimodal/clip/__init__.py create mode 100644 mmpretrain/models/multimodal/clip/clip.py create mode 100644 mmpretrain/models/multimodal/clip/clip_transformer.py create mode 100644 mmpretrain/models/multimodal/clip/utils.py create mode 100644 mmpretrain/models/multimodal/flamingo/__init__.py create mode 100644 mmpretrain/models/multimodal/flamingo/adapter.py create mode 100644 mmpretrain/models/multimodal/flamingo/flamingo.py create mode 100644 mmpretrain/models/multimodal/flamingo/modules.py create mode 100644 mmpretrain/models/multimodal/flamingo/utils.py create mode 100644 mmpretrain/models/multimodal/llava/__init__.py create mode 100644 mmpretrain/models/multimodal/llava/llava.py create mode 100644 mmpretrain/models/multimodal/llava/modules.py create mode 100644 mmpretrain/models/multimodal/minigpt4/__init__.py create mode 100644 mmpretrain/models/multimodal/minigpt4/minigpt4.py create mode 100644 mmpretrain/models/multimodal/ofa/__init__.py create mode 100644 mmpretrain/models/multimodal/ofa/ofa.py create mode 100644 mmpretrain/models/multimodal/ofa/ofa_modules.py create mode 100644 mmpretrain/models/multimodal/otter/__init__.py create mode 100644 mmpretrain/models/multimodal/otter/otter.py create mode 100644 mmpretrain/models/multimodal/ram/__init__.py create mode 100644 
mmpretrain/models/multimodal/ram/bert.py create mode 100644 mmpretrain/models/multimodal/ram/config/__init__.py create mode 100644 mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py create mode 100644 mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle create mode 100644 mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle create mode 100644 mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle create mode 100644 mmpretrain/models/multimodal/ram/gradio_demo.py create mode 100644 mmpretrain/models/multimodal/ram/openset_utils.py create mode 100644 mmpretrain/models/multimodal/ram/ram.py create mode 100644 mmpretrain/models/multimodal/ram/run/__init__.py create mode 100644 mmpretrain/models/multimodal/ram/run/inference.py create mode 100644 mmpretrain/models/multimodal/ram/utils.py create mode 100644 mmpretrain/models/necks/__init__.py create mode 100644 mmpretrain/models/necks/beitv2_neck.py create mode 100644 mmpretrain/models/necks/cae_neck.py create mode 100644 mmpretrain/models/necks/densecl_neck.py create mode 100644 mmpretrain/models/necks/gap.py create mode 100644 mmpretrain/models/necks/gem.py create mode 100644 mmpretrain/models/necks/hr_fuse.py create mode 100644 mmpretrain/models/necks/itpn_neck.py create mode 100644 mmpretrain/models/necks/linear_neck.py create mode 100644 mmpretrain/models/necks/mae_neck.py create mode 100644 mmpretrain/models/necks/milan_neck.py create mode 100644 mmpretrain/models/necks/mixmim_neck.py create mode 100644 mmpretrain/models/necks/mocov2_neck.py create mode 100644 mmpretrain/models/necks/nonlinear_neck.py create mode 100644 mmpretrain/models/necks/simmim_neck.py create mode 100644 mmpretrain/models/necks/spark_neck.py create mode 100644 mmpretrain/models/necks/swav_neck.py create mode 100644 mmpretrain/models/peft/__init__.py create mode 100644 mmpretrain/models/peft/lora.py create mode 100644 mmpretrain/models/retrievers/__init__.py create mode 100644 mmpretrain/models/retrievers/base.py create mode 100644 mmpretrain/models/retrievers/image2image.py create mode 100644 mmpretrain/models/selfsup/__init__.py create mode 100644 mmpretrain/models/selfsup/barlowtwins.py create mode 100644 mmpretrain/models/selfsup/base.py create mode 100644 mmpretrain/models/selfsup/beit.py create mode 100644 mmpretrain/models/selfsup/byol.py create mode 100644 mmpretrain/models/selfsup/cae.py create mode 100644 mmpretrain/models/selfsup/densecl.py create mode 100644 mmpretrain/models/selfsup/eva.py create mode 100644 mmpretrain/models/selfsup/itpn.py create mode 100644 mmpretrain/models/selfsup/mae.py create mode 100644 mmpretrain/models/selfsup/maskfeat.py create mode 100644 mmpretrain/models/selfsup/mff.py create mode 100644 mmpretrain/models/selfsup/milan.py create mode 100644 mmpretrain/models/selfsup/mixmim.py create mode 100644 mmpretrain/models/selfsup/moco.py create mode 100644 mmpretrain/models/selfsup/mocov3.py create mode 100644 mmpretrain/models/selfsup/simclr.py create mode 100644 mmpretrain/models/selfsup/simmim.py create mode 100644 mmpretrain/models/selfsup/simsiam.py create mode 100644 mmpretrain/models/selfsup/spark.py create mode 100644 mmpretrain/models/selfsup/swav.py create mode 100644 mmpretrain/models/tta/__init__.py create mode 100644 mmpretrain/models/tta/score_tta.py create mode 100644 mmpretrain/models/utils/__init__.py create mode 100644 mmpretrain/models/utils/attention.py create mode 100644 mmpretrain/models/utils/batch_augments/__init__.py create mode 100644 
mmpretrain/models/utils/batch_augments/cutmix.py create mode 100644 mmpretrain/models/utils/batch_augments/mixup.py create mode 100644 mmpretrain/models/utils/batch_augments/resizemix.py create mode 100644 mmpretrain/models/utils/batch_augments/wrapper.py create mode 100644 mmpretrain/models/utils/batch_shuffle.py create mode 100644 mmpretrain/models/utils/box_utils.py create mode 100644 mmpretrain/models/utils/channel_shuffle.py create mode 100644 mmpretrain/models/utils/clip_generator_helper.py create mode 100644 mmpretrain/models/utils/data_preprocessor.py create mode 100644 mmpretrain/models/utils/ema.py create mode 100644 mmpretrain/models/utils/embed.py create mode 100644 mmpretrain/models/utils/helpers.py create mode 100644 mmpretrain/models/utils/huggingface.py create mode 100644 mmpretrain/models/utils/inverted_residual.py create mode 100644 mmpretrain/models/utils/layer_scale.py create mode 100644 mmpretrain/models/utils/make_divisible.py create mode 100644 mmpretrain/models/utils/norm.py create mode 100644 mmpretrain/models/utils/position_encoding.py create mode 100644 mmpretrain/models/utils/res_layer_extra_norm.py create mode 100644 mmpretrain/models/utils/se_layer.py create mode 100644 mmpretrain/models/utils/sparse_modules.py create mode 100644 mmpretrain/models/utils/swiglu_ffn.py create mode 100644 mmpretrain/models/utils/tokenizer.py create mode 100644 mmpretrain/models/utils/vector_quantizer.py create mode 100644 mmpretrain/registry.py create mode 100644 mmpretrain/structures/__init__.py create mode 100644 mmpretrain/structures/data_sample.py create mode 100644 mmpretrain/structures/multi_task_data_sample.py create mode 100644 mmpretrain/structures/utils.py create mode 100644 mmpretrain/utils/__init__.py create mode 100644 mmpretrain/utils/analyze.py create mode 100644 mmpretrain/utils/collect_env.py create mode 100644 mmpretrain/utils/dependency.py create mode 100644 mmpretrain/utils/misc.py create mode 100644 mmpretrain/utils/progress.py create mode 100644 mmpretrain/utils/setup_env.py create mode 100644 mmpretrain/version.py create mode 100644 mmpretrain/visualization/__init__.py create mode 100644 mmpretrain/visualization/utils.py create mode 100644 mmpretrain/visualization/visualizer.py diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py new file mode 100644 index 0000000..66866a8 --- /dev/null +++ b/mmpretrain/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import mmengine +from mmengine.utils import digit_version + +from .apis import * # noqa: F401, F403 +from .version import __version__ + +mmcv_minimum_version = '2.0.0' +mmcv_maximum_version = '2.4.0' +mmcv_version = digit_version(mmcv.__version__) + +mmengine_minimum_version = '0.8.3' +mmengine_maximum_version = '1.0.0' +mmengine_version = digit_version(mmengine.__version__) + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version < digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.' + +assert (mmengine_version >= digit_version(mmengine_minimum_version) + and mmengine_version < digit_version(mmengine_maximum_version)), \ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_minimum_version}, ' \ + f'<{mmengine_maximum_version}.' 
+
+__all__ = ['__version__']
diff --git a/mmpretrain/apis/__init__.py b/mmpretrain/apis/__init__.py
new file mode 100644
index 0000000..6fbf443
--- /dev/null
+++ b/mmpretrain/apis/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseInferencer
+from .feature_extractor import FeatureExtractor
+from .image_caption import ImageCaptionInferencer
+from .image_classification import ImageClassificationInferencer
+from .image_retrieval import ImageRetrievalInferencer
+from .model import (ModelHub, get_model, inference_model, init_model,
+                    list_models)
+from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
+                                   TextToImageRetrievalInferencer)
+from .nlvr import NLVRInferencer
+from .visual_grounding import VisualGroundingInferencer
+from .visual_question_answering import VisualQuestionAnsweringInferencer
+
+__all__ = [
+    'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub',
+    'ImageClassificationInferencer', 'ImageRetrievalInferencer',
+    'FeatureExtractor', 'ImageCaptionInferencer',
+    'TextToImageRetrievalInferencer', 'VisualGroundingInferencer',
+    'VisualQuestionAnsweringInferencer', 'ImageToTextRetrievalInferencer',
+    'BaseInferencer', 'NLVRInferencer'
+]
diff --git a/mmpretrain/apis/base.py b/mmpretrain/apis/base.py
new file mode 100644
index 0000000..7bff6bd
--- /dev/null
+++ b/mmpretrain/apis/base.py
@@ -0,0 +1,390 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+from math import ceil
+from typing import Callable, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from mmengine.config import Config
+from mmengine.dataset import default_collate
+from mmengine.fileio import get_file_backend
+from mmengine.model import BaseModel
+from mmengine.runner import load_checkpoint
+
+from mmpretrain.structures import DataSample
+from mmpretrain.utils import track
+from .model import get_model, list_models
+
+ModelType = Union[BaseModel, str, Config]
+InputType = Union[str, np.ndarray, list]
+
+
+class BaseInferencer:
+    """Base inferencer for various tasks.
+
+    The BaseInferencer provides the standard workflow for inference as follows:
+
+    1. Preprocess the input data by :meth:`preprocess`.
+    2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
+       assumes the model inherits from :class:`mmengine.model.BaseModel` and
+       will call `model.test_step` in :meth:`forward` by default.
+    3. Visualize the results by :meth:`visualize`.
+    4. Postprocess and return the results by :meth:`postprocess`.
+
+    When we call the subclasses inherited from BaseInferencer (not overriding
+    ``__call__``), the workflow will be executed in order.
+
+    All subclasses of BaseInferencer could define the following class
+    attributes for customization:
+
+    - ``preprocess_kwargs``: The keys of the kwargs that will be passed to
+      :meth:`preprocess`.
+    - ``forward_kwargs``: The keys of the kwargs that will be passed to
+      :meth:`forward`.
+    - ``visualize_kwargs``: The keys of the kwargs that will be passed to
+      :meth:`visualize`.
+    - ``postprocess_kwargs``: The keys of the kwargs that will be passed to
+      :meth:`postprocess`.
+
+    All attributes mentioned above should be a ``set`` of keys (strings),
+    and each key should not be duplicated. :meth:`__call__` will
+    dispatch all the arguments to the corresponding methods according to the
+    ``xxx_kwargs`` mentioned above.
+
+    Subclasses inherited from ``BaseInferencer`` should implement
+    :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:
+
+    - _init_pipeline: Return a callable object to preprocess the input data.
+    - visualize: Visualize the results returned by :meth:`forward`.
+    - postprocess: Postprocess the results returned by :meth:`forward` and
+      :meth:`visualize`.
+
+    Args:
+        model (BaseModel | str | Config): A model name or a path to the config
+            file, or a :obj:`BaseModel` object. The model name can be found
+            by ``cls.list_models()`` and you can also query it in
+            :doc:`/modelzoo_statistics`.
+        pretrained (bool | str): A path to the checkpoint, or whether to load
+            the pre-defined weights of the specified model (the latter only
+            works if ``model`` is a model name). Defaults to True.
+        device (str | torch.device | None): Transfer the model to the target
+            device. Defaults to None.
+        device_map (str | dict | None): A map that specifies where each
+            submodule should go. It doesn't need to be refined to each
+            parameter/buffer name; once a given module name is inside, every
+            submodule of it will be sent to the same device. You can use
+            `device_map="auto"` to automatically generate the device map.
+            Defaults to None.
+        offload_folder (str | None): If the `device_map` contains any value
+            `"disk"`, the folder where we will offload weights.
+        **kwargs: Other keyword arguments to initialize the model (only works
+            if ``model`` is a model name).
+    """
+
+    preprocess_kwargs: set = set()
+    forward_kwargs: set = set()
+    visualize_kwargs: set = set()
+    postprocess_kwargs: set = set()
+
+    def __init__(self,
+                 model: ModelType,
+                 pretrained: Union[bool, str] = True,
+                 device: Union[str, torch.device, None] = None,
+                 device_map=None,
+                 offload_folder=None,
+                 **kwargs) -> None:
+
+        if isinstance(model, BaseModel):
+            if isinstance(pretrained, str):
+                load_checkpoint(model, pretrained, map_location='cpu')
+            if device_map is not None:
+                from .utils import dispatch_model
+                model = dispatch_model(
+                    model,
+                    device_map=device_map,
+                    offload_folder=offload_folder)
+            elif device is not None:
+                model.to(device)
+        else:
+            model = get_model(
+                model,
+                pretrained,
+                device=device,
+                device_map=device_map,
+                offload_folder=offload_folder,
+                **kwargs)
+
+        model.eval()
+
+        self.config = model._config
+        self.model = model
+        self.pipeline = self._init_pipeline(self.config)
+        self.visualizer = None
+
+    def __call__(
+        self,
+        inputs,
+        return_datasamples: bool = False,
+        batch_size: int = 1,
+        **kwargs,
+    ) -> dict:
+        """Call the inferencer.
+
+        Args:
+            inputs (InputsType): Inputs for the inferencer.
+            return_datasamples (bool): Whether to return results as
+                :obj:`BaseDataElement`. Defaults to False.
+            batch_size (int): Batch size. Defaults to 1.
+            **kwargs: Keyword arguments passed to :meth:`preprocess`,
+                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+                Each key in kwargs should be in the corresponding set of
+                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
+                and ``postprocess_kwargs``.
+
+        Returns:
+            dict: Inference and visualization results.
+ """ + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + ori_inputs = self._inputs_to_list(inputs) + inputs = self.preprocess( + ori_inputs, batch_size=batch_size, **preprocess_kwargs) + preds = [] + for data in track( + inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)): + preds.extend(self.forward(data, **forward_kwargs)) + visualization = self.visualize(ori_inputs, preds, **visualize_kwargs) + results = self.postprocess(preds, visualization, return_datasamples, + **postprocess_kwargs) + return results + + def _inputs_to_list(self, inputs: InputType) -> list: + """Preprocess the inputs to a list. + + Cast the input data to a list of data. + + - list or tuple: return inputs + - str: + - Directory path: return all files in the directory + - other cases: return a list containing the string. The string + could be a path to file, a url or other types of string according + to the task. + - other: return a list with one item. + + Args: + inputs (str | array | list): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. + """ + if isinstance(inputs, str): + backend = get_file_backend(inputs) + if hasattr(backend, 'isdir') and backend.isdir(inputs): + # Backends like HttpsBackend do not implement `isdir`, so only + # those backends that implement `isdir` could accept the inputs + # as a directory + file_list = backend.list_dir_or_file(inputs, list_dir=False) + inputs = [ + backend.join_path(inputs, file) for file in file_list + ] + + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + return list(inputs) + + def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs): + """Process the inputs into a model-feedable format. + + Customize your preprocess by overriding this method. Preprocess should + return an iterable object, of which each item will be used as the + input of ``model.test_step``. + + ``BaseInferencer.preprocess`` will return an iterable chunked data, + which will be used in __call__ like this: + + .. code-block:: python + + def __call__(self, inputs, batch_size=1, **kwargs): + chunked_data = self.preprocess(inputs, batch_size, **kwargs) + for batch in chunked_data: + preds = self.forward(batch, **kwargs) + + Args: + inputs (InputsType): Inputs given by user. + batch_size (int): batch size. Defaults to 1. + + Yields: + Any: Data processed by the ``pipeline`` and ``default_collate``. + """ + chunked_data = self._get_chunk_data( + map(self.pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + @torch.no_grad() + def forward(self, inputs: Union[dict, tuple], **kwargs): + """Feed the inputs to the model.""" + return self.model.test_step(inputs) + + def visualize(self, + inputs: list, + preds: List[DataSample], + show: bool = False, + **kwargs) -> List[np.ndarray]: + """Visualize predictions. + + Customize your visualization by overriding this method. visualize + should return visualization results, which could be np.ndarray or any + other objects. + + Args: + inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. + preds (Any): Predictions of the model. + show (bool): Whether to display the image in a popup window. + Defaults to False. + + Returns: + List[np.ndarray]: Visualization results. 
+ """ + if show: + raise NotImplementedError( + f'The `visualize` method of {self.__class__.__name__} ' + 'is not implemented.') + + @abstractmethod + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasample=False, + **kwargs, + ) -> dict: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Convert datasamples into a json-serializable dict if needed. + 2. Pack the predictions and visualization results and return them. + 3. Dump or log the predictions. + + Customize your postprocess by overriding this method. Make sure + ``postprocess`` will return a dict with visualization results and + inference results. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (np.ndarray): Visualized predictions. + return_datasample (bool): Whether to return results as datasamples. + Defaults to False. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization`` + + - ``visualization (Any)``: Returned by :meth:`visualize` + - ``predictions`` (dict or DataSample): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it usually should be a + json-serializable dict containing only basic data elements such + as strings and numbers. + """ + + @abstractmethod + def _init_pipeline(self, cfg: Config) -> Callable: + """Initialize the test pipeline. + + Return a pipeline to handle various input data, such as ``str``, + ``np.ndarray``. It is an abstract method in BaseInferencer, and should + be implemented in subclasses. + + The returned pipeline will be used to process a single data. + It will be used in :meth:`preprocess` like this: + + .. code-block:: python + def preprocess(self, inputs, batch_size, **kwargs): + ... + dataset = map(self.pipeline, dataset) + ... + """ + + def _get_chunk_data(self, inputs: Iterable, chunk_size: int): + """Get batch data from dataset. + + Args: + inputs (Iterable): An iterable dataset. + chunk_size (int): Equivalent to batch size. + + Yields: + list: batch data. + """ + inputs_iter = iter(inputs) + while True: + try: + chunk_data = [] + for _ in range(chunk_size): + processed_data = next(inputs_iter) + chunk_data.append(processed_data) + yield chunk_data + except StopIteration: + if chunk_data: + yield chunk_data + break + + def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]: + """Dispatch kwargs to preprocess(), forward(), visualize() and + postprocess() according to the actual demands. + + Returns: + Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess, + forward, visualize and postprocess respectively. 
+ """ + # Ensure each argument only matches one function + method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \ + self.visualize_kwargs | self.postprocess_kwargs + + union_kwargs = method_kwargs | set(kwargs.keys()) + if union_kwargs != method_kwargs: + unknown_kwargs = union_kwargs - method_kwargs + raise ValueError( + f'unknown argument {unknown_kwargs} for `preprocess`, ' + '`forward`, `visualize` and `postprocess`') + + preprocess_kwargs = {} + forward_kwargs = {} + visualize_kwargs = {} + postprocess_kwargs = {} + + for key, value in kwargs.items(): + if key in self.preprocess_kwargs: + preprocess_kwargs[key] = value + if key in self.forward_kwargs: + forward_kwargs[key] = value + if key in self.visualize_kwargs: + visualize_kwargs[key] = value + if key in self.postprocess_kwargs: + postprocess_kwargs[key] = value + + return ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List models defined in metafile of corresponding packages. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern) diff --git a/mmpretrain/apis/feature_extractor.py b/mmpretrain/apis/feature_extractor.py new file mode 100644 index 0000000..ee14f92 --- /dev/null +++ b/mmpretrain/apis/feature_extractor.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, List, Optional, Union + +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from .base import BaseInferencer, InputType +from .model import list_models + + +class FeatureExtractor(BaseInferencer): + """The inferencer for extract features. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``FeatureExtractor.list_models()`` and you can also query it in + :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import FeatureExtractor + >>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3))) + >>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0] + >>> for feat in feats: + >>> print(feat.shape) + torch.Size([256, 56, 56]) + torch.Size([512, 28, 28]) + torch.Size([1024, 14, 14]) + torch.Size([2048, 7, 7]) + """ # noqa: E501 + + def __call__(self, + inputs: InputType, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + batch_size (int): Batch size. Defaults to 1. + **kwargs: Other keyword arguments accepted by the `extract_feat` + method of the model. + + Returns: + tensor | Tuple[tensor]: The extracted features. 
+ """ + ori_inputs = self._inputs_to_list(inputs) + inputs = self.preprocess(ori_inputs, batch_size=batch_size) + preds = [] + for data in inputs: + preds.extend(self.forward(data, **kwargs)) + + return preds + + @torch.no_grad() + def forward(self, inputs: Union[dict, tuple], **kwargs): + inputs = self.model.data_preprocessor(inputs, False)['inputs'] + outputs = self.model.extract_feat(inputs, **kwargs) + + def scatter(feats, index): + if isinstance(feats, torch.Tensor): + return feats[index] + else: + # Sequence of tensor + return type(feats)([scatter(item, index) for item in feats]) + + results = [] + for i in range(inputs.shape[0]): + results.append(scatter(outputs, i)) + + return results + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self): + raise NotImplementedError( + "The FeatureExtractor doesn't support visualization.") + + def postprocess(self): + raise NotImplementedError( + "The FeatureExtractor doesn't need postprocessing.") + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern) diff --git a/mmpretrain/apis/image_caption.py b/mmpretrain/apis/image_caption.py new file mode 100644 index 0000000..c11c0d3 --- /dev/null +++ b/mmpretrain/apis/image_caption.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional + +import numpy as np +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer, InputType +from .model import list_models + + +class ImageCaptionInferencer(BaseInferencer): + """The inferencer for image caption. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageCaptionInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). 
+ + Example: + >>> from mmpretrain import ImageCaptionInferencer + >>> inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption') + >>> inferencer('demo/cat-dog.png')[0] + {'pred_caption': 'a puppy and a cat sitting on a blanket'} + """ # noqa: E501 + + visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'} + + def __call__(self, + images: InputType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + images (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(images, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. 
+ image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_image_caption( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({'pred_caption': data_sample.get('pred_caption')}) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Caption') diff --git a/mmpretrain/apis/image_classification.py b/mmpretrain/apis/image_classification.py new file mode 100644 index 0000000..a202180 --- /dev/null +++ b/mmpretrain/apis/image_classification.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer, InputType, ModelType +from .model import list_models + + +class ImageClassificationInferencer(BaseInferencer): + """The inferencer for image classification. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageClassificationInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + 1. Use a pre-trained model in MMPreTrain to inference an image. + + >>> from mmpretrain import ImageClassificationInferencer + >>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k') + >>> inferencer('demo/demo.JPEG') + [{'pred_score': array([...]), + 'pred_label': 65, + 'pred_score': 0.6649367809295654, + 'pred_class': 'sea snake'}] + + 2. Use a config file and checkpoint to inference multiple images on GPU, + and save the visualization results in a folder. 
+ + >>> from mmpretrain import ImageClassificationInferencer + >>> inferencer = ImageClassificationInferencer( + model='configs/resnet/resnet50_8xb32_in1k.py', + pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth', + device='cuda') + >>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/") + """ # noqa: E501 + + visualize_kwargs: set = { + 'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir', + 'wait_time' + } + + def __init__(self, + model: ModelType, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + classes=None, + **kwargs) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + if classes is not None: + self.classes = classes + else: + self.classes = getattr(self.model, '_dataset_meta', + {}).get('classes') + + def __call__(self, + inputs: InputType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + rescale_factor (float, optional): Rescale the image by the rescale + factor for visualization. This is helpful when the image is too + large or too small for visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__( + inputs, + return_datasamples=return_datasamples, + batch_size=batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. 
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + rescale_factor: Optional[float] = None, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_cls( + image, + data_sample, + classes=self.classes, + resize=resize, + show=show, + wait_time=wait_time, + rescale_factor=rescale_factor, + draw_gt=False, + draw_pred=True, + draw_score=draw_score, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + pred_scores = data_sample.pred_score + pred_score = float(torch.max(pred_scores).item()) + pred_label = torch.argmax(pred_scores).item() + result = { + 'pred_scores': pred_scores.detach().cpu().numpy(), + 'pred_label': pred_label, + 'pred_score': pred_score, + } + if self.classes is not None: + result['pred_class'] = self.classes[pred_label] + results.append(result) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Classification') diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py new file mode 100644 index 0000000..27919b2 --- /dev/null +++ b/mmpretrain/apis/image_retrieval.py @@ -0,0 +1,288 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
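Because ``self.classes`` falls back to the checkpoint's ``dataset_meta``, the explicit ``classes`` argument of ``ImageClassificationInferencer`` matters mainly for custom checkpoints that carry no meta information. A hedged sketch in which the config and checkpoint paths are illustrative placeholders for a fine-tuned two-class model:

.. code-block:: python

    from mmpretrain import ImageClassificationInferencer

    # Both paths below are placeholders for a fine-tuned 2-class model whose
    # checkpoint carries no `dataset_meta`.
    inferencer = ImageClassificationInferencer(
        'work_dirs/resnet18_2class/resnet18_2class.py',
        pretrained='work_dirs/resnet18_2class/epoch_100.pth',
        classes=['cat', 'dog'])

    result = inferencer('demo/cat-dog.png')[0]
    print(result['pred_label'], result['pred_score'], result['pred_class'])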
+from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import BaseDataset, Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer, InputType, ModelType +from .model import list_models + + +class ImageRetrievalInferencer(BaseInferencer): + """The inferencer for image to image retrieval. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageRetrievalInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + prototype (str | list | dict | DataLoader, BaseDataset): The images to + be retrieved. It can be the following types: + + - str: The directory of the the images. + - list: A list of path of the images. + - dict: A config dict of the a prototype dataset. + - BaseDataset: A prototype dataset. + - DataLoader: A data loader to load the prototype data. + + prototype_cache (str, optional): The path of the generated prototype + features. If exists, directly load the cache instead of re-generate + the prototype features. If not exists, save the generated features + to the path. Defaults to None. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import ImageRetrievalInferencer + >>> inferencer = ImageRetrievalInferencer( + ... 'resnet50-arcface_inshop', + ... prototype='./demo/', + ... 
prototype_cache='img_retri.pth') + >>> inferencer('demo/cat-dog.png', topk=2)[0][1] + {'match_score': tensor(0.4088, device='cuda:0'), + 'sample_idx': 3, + 'sample': {'img_path': './demo/dog.jpg'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__( + self, + model: ModelType, + prototype, + prototype_cache=None, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs, + ) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.prototype_dataset = self._prepare_prototype( + prototype, prototype_cache, prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A directory path of images + prototype = dict( + type='CustomDataset', with_label=False, data_root=prototype) + + if isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, dict): + # A config of dataset + from mmpretrain.registry import DATASETS + test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + prototype.setdefault('pipeline', test_pipeline) + dataset = DATASETS.build(prototype) + dataloader = build_dataloader(dataset) + elif isinstance(prototype, DataLoader): + dataset = prototype.dataset + dataloader = prototype + elif isinstance(prototype, BaseDataset): + dataset = prototype + dataloader = build_dataloader(dataset) + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + if cache is not None and Path(cache).exists(): + self.model.prototype = cache + else: + self.model.prototype = dataloader + self.model.prepare_prototype() + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + self.model.dump_prototype(path) + + def __call__(self, + inputs: InputType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. 
+ wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[InputType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[InputType], + preds: List[DataSample], + topk: int = 3, + resize: Optional[int] = 224, + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_image_retrieval( + image, + data_sample, + self.prototype_dataset, + topk=topk, + resize=resize, + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + sample = self.prototype_dataset.get_data_info( + sample_idx.item()) + sample_idx = sample.pop('sample_idx') + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'sample': sample + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Image Retrieval') diff --git a/mmpretrain/apis/model.py b/mmpretrain/apis/model.py new file mode 100644 index 0000000..eba475e --- /dev/null +++ b/mmpretrain/apis/model.py @@ -0,0 +1,408 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
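The cache handling in ``_prepare_prototype`` above means the expensive feature extraction over the prototype set only happens once; a short sketch reusing the names from the ``ImageRetrievalInferencer`` docstring:

.. code-block:: python

    from mmpretrain import ImageRetrievalInferencer

    # First run: features of every image under ./demo/ are computed and,
    # because the cache file does not exist yet, dumped to img_retri.pth.
    inferencer = ImageRetrievalInferencer(
        'resnet50-arcface_inshop',
        prototype='./demo/',
        prototype_cache='img_retri.pth')

    # Later runs with the same arguments load img_retri.pth directly instead
    # of re-running `prepare_prototype` over the folder.
    print(inferencer('demo/cat-dog.png', topk=2)[0])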
+import copy +import fnmatch +import os.path as osp +import re +import warnings +from os import PathLike +from pathlib import Path +from typing import List, Tuple, Union + +from mmengine.config import Config +from modelindex.load_model_index import load +from modelindex.models.Model import Model + + +class ModelHub: + """A hub to host the meta information of all pre-defined models.""" + _models_dict = {} + __mmpretrain_registered = False + + @classmethod + def register_model_index(cls, + model_index_path: Union[str, PathLike], + config_prefix: Union[str, PathLike, None] = None): + """Parse the model-index file and register all models. + + Args: + model_index_path (str | PathLike): The path of the model-index + file. + config_prefix (str | PathLike | None): The prefix of all config + file paths in the model-index file. + """ + model_index = load(str(model_index_path)) + model_index.build_models_with_collections() + + for metainfo in model_index.models: + model_name = metainfo.name.lower() + if metainfo.name in cls._models_dict: + raise ValueError( + 'The model name {} is conflict in {} and {}.'.format( + model_name, osp.abspath(metainfo.filepath), + osp.abspath(cls._models_dict[model_name].filepath))) + metainfo.config = cls._expand_config_path(metainfo, config_prefix) + cls._models_dict[model_name] = metainfo + + @classmethod + def get(cls, model_name): + """Get the model's metainfo by the model name. + + Args: + model_name (str): The name of model. + + Returns: + modelindex.models.Model: The metainfo of the specified model. + """ + cls._register_mmpretrain_models() + # lazy load config + metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower())) + if metainfo is None: + raise ValueError( + f'Failed to find model "{model_name}". please use ' + '`mmpretrain.list_models` to get all available names.') + if isinstance(metainfo.config, str): + metainfo.config = Config.fromfile(metainfo.config) + return metainfo + + @staticmethod + def _expand_config_path(metainfo: Model, + config_prefix: Union[str, PathLike] = None): + if config_prefix is None: + config_prefix = osp.dirname(metainfo.filepath) + + if metainfo.config is None or osp.isabs(metainfo.config): + config_path: str = metainfo.config + else: + config_path = osp.abspath(osp.join(config_prefix, metainfo.config)) + + return config_path + + @classmethod + def _register_mmpretrain_models(cls): + # register models in mmpretrain + if not cls.__mmpretrain_registered: + from importlib_metadata import distribution + root = distribution('mmpretrain').locate_file('mmpretrain') + model_index_path = root / '.mim' / 'model-index.yml' + ModelHub.register_model_index( + model_index_path, config_prefix=root / '.mim') + cls.__mmpretrain_registered = True + + @classmethod + def has(cls, model_name): + """Whether a model name is in the ModelHub.""" + return model_name in cls._models_dict + + +def get_model(model: Union[str, Config], + pretrained: Union[str, bool] = False, + device=None, + device_map=None, + offload_folder=None, + url_mapping: Tuple[str, str] = None, + **kwargs): + """Get a pre-defined model or create a model from config. + + Args: + model (str | Config): The name of model, the config file path or a + config instance. + pretrained (bool | str): When use name to specify model, you can + use ``True`` to load the pre-defined pretrained weights. And you + can also use a string to specify the path or link of weights to + load. Defaults to False. + device (str | torch.device | None): Transfer the model to the target + device. 
Defaults to None. + device_map (str | dict | None): A map that specifies where each + submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every + submodule of it will be sent to the same device. You can use + `device_map="auto"` to automatically generate the device map. + Defaults to None. + offload_folder (str | None): If the `device_map` contains any value + `"disk"`, the folder where we will offload weights. + url_mapping (Tuple[str, str], optional): The mapping of pretrained + checkpoint link. For example, load checkpoint from a local dir + instead of download by ``('https://.*/', './checkpoint')``. + Defaults to None. + **kwargs: Other keyword arguments of the model config. + + Returns: + mmengine.model.BaseModel: The result model. + + Examples: + Get a ResNet-50 model and extract images feature: + + >>> import torch + >>> from mmpretrain import get_model + >>> inputs = torch.rand(16, 3, 224, 224) + >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3))) + >>> feats = model.extract_feat(inputs) + >>> for feat in feats: + ... print(feat.shape) + torch.Size([16, 256]) + torch.Size([16, 512]) + torch.Size([16, 1024]) + torch.Size([16, 2048]) + + Get Swin-Transformer model with pre-trained weights and inference: + + >>> from mmpretrain import get_model, inference_model + >>> model = get_model('swin-base_16xb64_in1k', pretrained=True) + >>> result = inference_model(model, 'demo/demo.JPEG') + >>> print(result['pred_class']) + 'sea snake' + """ # noqa: E501 + if device_map is not None: + from .utils import dispatch_model + dispatch_model._verify_require() + + metainfo = None + if isinstance(model, Config): + config = copy.deepcopy(model) + if pretrained is True and 'load_from' in config: + pretrained = config.load_from + elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py': + config = Config.fromfile(model) + if pretrained is True and 'load_from' in config: + pretrained = config.load_from + elif isinstance(model, str): + metainfo = ModelHub.get(model) + config = metainfo.config + if pretrained is True and metainfo.weights is not None: + pretrained = metainfo.weights + else: + raise TypeError('model must be a name, a path or a Config object, ' + f'but got {type(config)}') + + if pretrained is True: + warnings.warn('Unable to find pre-defined checkpoint of the model.') + pretrained = None + elif pretrained is False: + pretrained = None + + if kwargs: + config.merge_from_dict({'model': kwargs}) + config.model.setdefault('data_preprocessor', + config.get('data_preprocessor', None)) + + from mmengine.registry import DefaultScope + + from mmpretrain.registry import MODELS + with DefaultScope.overwrite_default_scope('mmpretrain'): + model = MODELS.build(config.model) + + dataset_meta = {} + if pretrained: + # Mapping the weights to GPU may cause unexpected video memory leak + # which refers to https://github.com/open-mmlab/mmdetection/pull/6405 + from mmengine.runner import load_checkpoint + if url_mapping is not None: + pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained) + checkpoint = load_checkpoint(model, pretrained, map_location='cpu') + if 'dataset_meta' in checkpoint.get('meta', {}): + # mmpretrain 1.x + dataset_meta = checkpoint['meta']['dataset_meta'] + elif 'CLASSES' in checkpoint.get('meta', {}): + # mmcls 0.x + dataset_meta = {'classes': checkpoint['meta']['CLASSES']} + + if len(dataset_meta) == 0 and 'test_dataloader' in config: + from 
mmpretrain.registry import DATASETS + dataset_class = DATASETS.get(config.test_dataloader.dataset.type) + dataset_meta = getattr(dataset_class, 'METAINFO', {}) + + if device_map is not None: + model = dispatch_model( + model, device_map=device_map, offload_folder=offload_folder) + elif device is not None: + model.to(device) + + model._dataset_meta = dataset_meta # save the dataset meta + model._config = config # save the config in the model + model._metainfo = metainfo # save the metainfo in the model + model.eval() + return model + + +def init_model(config, checkpoint=None, device=None, **kwargs): + """Initialize a classifier from config file (deprecated). + + It's only for compatibility, please use :func:`get_model` instead. + + Args: + config (str | :obj:`mmengine.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str | torch.device | None): Transfer the model to the target + device. Defaults to None. + **kwargs: Other keyword arguments of the model config. + + Returns: + nn.Module: The constructed model. + """ + return get_model(config, checkpoint, device, **kwargs) + + +def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]: + """List all models available in MMPretrain. + + Args: + pattern (str | None): A wildcard pattern to match model names. + Defaults to None. + exclude_patterns (list | None): A list of wildcard patterns to + exclude names from the matched names. Defaults to None. + task (str | none): The evaluation task of the model. + + Returns: + List[str]: a list of model names. + + Examples: + List all models: + + >>> from mmpretrain import list_models + >>> list_models() + + List ResNet-50 models on ImageNet-1k dataset: + + >>> from mmpretrain import list_models + >>> list_models('resnet*in1k') + ['resnet50_8xb32_in1k', + 'resnet50_8xb32-fp16_in1k', + 'resnet50_8xb256-rsb-a1-600e_in1k', + 'resnet50_8xb256-rsb-a2-300e_in1k', + 'resnet50_8xb256-rsb-a3-100e_in1k'] + + List Swin-Transformer models trained from stratch and exclude + Swin-Transformer-V2 models: + + >>> from mmpretrain import list_models + >>> list_models('swin', exclude_patterns=['swinv2', '*-pre']) + ['swin-base_16xb64_in1k', + 'swin-base_3rdparty_in1k', + 'swin-base_3rdparty_in1k-384', + 'swin-large_8xb8_cub-384px', + 'swin-small_16xb64_in1k', + 'swin-small_3rdparty_in1k', + 'swin-tiny_16xb64_in1k', + 'swin-tiny_3rdparty_in1k'] + + List all EVA models for image classification task. + + >>> from mmpretrain import list_models + >>> list_models('eva', task='Image Classification') + ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px', + 'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px', + 'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px', + 'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px', + 'eva-l-p14_mim-pre_3rdparty_in1k-196px', + 'eva-l-p14_mim-pre_3rdparty_in1k-336px'] + """ + ModelHub._register_mmpretrain_models() + matches = set(ModelHub._models_dict.keys()) + + if pattern is not None: + # Always match keys with any postfix. 
+ matches = set(fnmatch.filter(matches, pattern + '*')) + + exclude_patterns = exclude_patterns or [] + for exclude_pattern in exclude_patterns: + exclude = set(fnmatch.filter(matches, exclude_pattern + '*')) + matches = matches - exclude + + if task is not None: + task_matches = [] + for key in matches: + metainfo = ModelHub._models_dict[key] + if metainfo.results is None and task == 'null': + task_matches.append(key) + elif metainfo.results is None: + continue + elif task in [result.task for result in metainfo.results]: + task_matches.append(key) + matches = task_matches + + return sorted(list(matches)) + + +def inference_model(model, *args, **kwargs): + """Inference an image with the inferencer. + + Automatically select inferencer to inference according to the type of + model. It's a shortcut for a quick start, and for advanced usage, please + use the correspondding inferencer class. + + Here is the mapping from task to inferencer: + + - Image Classification: :class:`ImageClassificationInferencer` + - Image Retrieval: :class:`ImageRetrievalInferencer` + - Image Caption: :class:`ImageCaptionInferencer` + - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer` + - Visual Grounding: :class:`VisualGroundingInferencer` + - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer` + - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer` + - NLVR: :class:`NLVRInferencer` + + Args: + model (BaseModel | str | Config): The loaded model, the model + name or the config of the model. + *args: Positional arguments to call the inferencer. + **kwargs: Other keyword arguments to initialize and call the + correspondding inferencer. + + Returns: + result (dict): The inference results. + """ # noqa: E501 + from mmengine.model import BaseModel + + if isinstance(model, BaseModel): + metainfo = getattr(model, '_metainfo', None) + else: + metainfo = ModelHub.get(model) + + from inspect import signature + + from .image_caption import ImageCaptionInferencer + from .image_classification import ImageClassificationInferencer + from .image_retrieval import ImageRetrievalInferencer + from .multimodal_retrieval import (ImageToTextRetrievalInferencer, + TextToImageRetrievalInferencer) + from .nlvr import NLVRInferencer + from .visual_grounding import VisualGroundingInferencer + from .visual_question_answering import VisualQuestionAnsweringInferencer + task_mapping = { + 'Image Classification': ImageClassificationInferencer, + 'Image Retrieval': ImageRetrievalInferencer, + 'Image Caption': ImageCaptionInferencer, + 'Visual Question Answering': VisualQuestionAnsweringInferencer, + 'Visual Grounding': VisualGroundingInferencer, + 'Text-To-Image Retrieval': TextToImageRetrievalInferencer, + 'Image-To-Text Retrieval': ImageToTextRetrievalInferencer, + 'NLVR': NLVRInferencer, + } + + inferencer_type = None + + if metainfo is not None and metainfo.results is not None: + tasks = set(result.task for result in metainfo.results) + inferencer_type = [ + task_mapping.get(task) for task in tasks if task in task_mapping + ] + if len(inferencer_type) > 1: + inferencer_names = [cls.__name__ for cls in inferencer_type] + warnings.warn('The model supports multiple tasks, auto select ' + f'{inferencer_names[0]}, you can also use other ' + f'inferencer {inferencer_names} directly.') + inferencer_type = inferencer_type[0] + + if inferencer_type is None: + raise NotImplementedError('No available inferencer for the model') + + init_kwargs = { + k: kwargs.pop(k) + for k in list(kwargs) + if k in 
signature(inferencer_type).parameters.keys() + } + + inferencer = inferencer_type(model, **init_kwargs) + return inferencer(*args, **kwargs)[0] diff --git a/mmpretrain/apis/multimodal_retrieval.py b/mmpretrain/apis/multimodal_retrieval.py new file mode 100644 index 0000000..5eb9c85 --- /dev/null +++ b/mmpretrain/apis/multimodal_retrieval.py @@ -0,0 +1,603 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from pathlib import Path +from typing import Callable, List, Optional, Tuple, Union + +import mmengine +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import BaseDataset, Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from mmpretrain.utils import track +from .base import BaseInferencer +from .base import InputType as ImageType +from .base import ModelType +from .model import list_models + + +def filter_transforms(transforms: list, data_info: dict): + """Filter pipeline to avoid KeyError with partial data info.""" + data_info = deepcopy(data_info) + filtered_transforms = [] + for t in transforms: + try: + data_info = t(data_info) + filtered_transforms.append(t) + except KeyError: + pass + return filtered_transforms + + +class TextToImageRetrievalInferencer(BaseInferencer): + """The inferencer for text to image retrieval. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``TextToImageRetrievalInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + prototype (str | list | dict | DataLoader | BaseDataset): The images to + be retrieved. It can be the following types: + + - str: The directory of the the images. + - list: A list of path of the images. + - dict: A config dict of the a prototype dataset. + - BaseDataset: A prototype dataset. + - DataLoader: A data loader to load the prototype data. + + prototype_cache (str, optional): The path of the generated prototype + features. If exists, directly load the cache instead of re-generate + the prototype features. If not exists, save the generated features + to the path. Defaults to None. + fast_match (bool): Some algorithms will record extra image features for + further matching, which may consume large memory, set True to avoid + this behavior. Defaults to True. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import TextToImageRetrievalInferencer + >>> inferencer = TextToImageRetrievalInferencer( + ... 'blip-base_3rdparty_retrieval', + ... prototype='./demo/', + ... 
prototype_cache='t2i_retri.pth') + >>> inferencer('A cat and a dog.')[0] + {'match_score': tensor(0.3855, device='cuda:0'), + 'sample_idx': 1, + 'sample': {'img_path': './demo/cat-dog.png'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'show_dir', 'show', 'wait_time', 'figsize', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__(self, + model: ModelType, + prototype, + prototype_cache=None, + fast_match=True, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.img_pipeline, self.text_pipeline = self.pipeline + + if hasattr(self.model, 'fast_match'): + self.model.fast_match = fast_match + + self.prototype_dataset = self._prepare_prototype( + prototype, prototype_cache, batch_size=prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A directory path of images + prototype = dict( + type='CustomDataset', with_label=False, data_root=prototype) + + if isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, dict): + # A config of dataset + from mmpretrain.registry import DATASETS + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + prototype.setdefault('pipeline', test_pipeline) + dataset = DATASETS.build(prototype) + dataloader = build_dataloader(dataset) + elif isinstance(prototype, list): + test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline] + dataset = BaseDataset( + lazy_init=True, serialize_data=False, pipeline=test_pipeline) + dataset.data_list = [{ + 'sample_idx': i, + 'img_path': file + } for i, file in enumerate(prototype)] + dataset._fully_initialized = True + dataloader = build_dataloader(dataset) + elif isinstance(prototype, DataLoader): + dataset = prototype.dataset + dataloader = prototype + elif isinstance(prototype, BaseDataset): + dataset = prototype + dataloader = build_dataloader(dataset) + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + if cache is not None and Path(cache).exists(): + self.prototype = torch.load(cache) + else: + prototype = [] + for data_batch in track(dataloader, 'Prepare prototype...'): + with torch.no_grad(): + data_batch = self.model.data_preprocessor( + data_batch, False) + feats = self.model._run_forward(data_batch, mode='tensor') + prototype.append(feats) + prototype = { + k: torch.cat([d[k] for d in prototype]) + for k in prototype[0] + } + self.prototype = prototype + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + 
logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + torch.save(self.prototype, path) + + def __call__(self, + inputs: ImageType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. + """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + @torch.no_grad() + def forward(self, data: dict, **kwargs): + """Feed the inputs to the model.""" + data = self.model.data_preprocessor(data, False) + data_samples = data['data_samples'] + feats = self.prototype.copy() + feats.update(self.model.extract_feat(data_samples=data_samples)) + return self.model.predict_all(feats, data_samples, cal_i2t=False)[0] + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + test_transfroms = [TRANSFORMS.build(t) for t in test_pipeline_cfg] + img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)} + text_info = {'text': 'example'} + img_pipeline = Compose(filter_transforms(test_transfroms, img_info)) + text_pipeline = Compose(filter_transforms(test_transfroms, text_info)) + return img_pipeline, text_pipeline + + def preprocess(self, inputs: List[str], batch_size: int = 1): + + def process_text(input_: str): + return self.text_pipeline({'text': input_}) + + chunked_data = self._get_chunk_data( + map(process_text, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[str], + preds: List[DataSample], + topk: int = 3, + figsize: Tuple[int, int] = (16, 9), + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (text, data_sample) in enumerate(zip(ori_inputs, preds)): + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_t2i_retrieval( + text, + data_sample, + self.prototype_dataset, + topk=topk, + fig_cfg=dict(figsize=figsize), + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = 
torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + sample = self.prototype_dataset.get_data_info( + sample_idx.item()) + sample_idx = sample.pop('sample_idx') + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'sample': sample + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Text-To-Image Retrieval') + + +class ImageToTextRetrievalInferencer(BaseInferencer): + """The inferencer for image to text retrieval. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``ImageToTextRetrievalInferencer.list_models()`` and you can + also query it in :doc:`/modelzoo_statistics`. + prototype (str | list | dict | DataLoader, BaseDataset): The images to + be retrieved. It can be the following types: + + - str: The file path to load the string list. + - list: A list of string. + + prototype_cache (str, optional): The path of the generated prototype + features. If exists, directly load the cache instead of re-generate + the prototype features. If not exists, save the generated features + to the path. Defaults to None. + fast_match (bool): Some algorithms will record extra image features for + further matching, which may consume large memory, set True to avoid + this behavior. Defaults to True. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import ImageToTextRetrievalInferencer + >>> inferencer = ImageToTextRetrievalInferencer( + ... 'blip-base_3rdparty_retrieval', + ... prototype=['cat', 'dog', 'snake', 'bird'], + ... 
prototype_cache='i2t_retri.pth') + >>> inferencer('demo/bird.JPEG')[0] + {'match_score': tensor(0.3855, device='cuda:0'), + 'sample_idx': 1, + 'sample': {'img_path': './demo/cat-dog.png'}} + """ # noqa: E501 + + visualize_kwargs: set = { + 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk' + } + postprocess_kwargs: set = {'topk'} + + def __init__(self, + model: ModelType, + prototype, + prototype_cache=None, + fast_match=True, + prepare_batch_size=8, + pretrained: Union[bool, str] = True, + device: Union[str, torch.device, None] = None, + **kwargs) -> None: + super().__init__( + model=model, pretrained=pretrained, device=device, **kwargs) + + self.img_pipeline, self.text_pipeline = self.pipeline + + if hasattr(self.model, 'fast_match'): + self.model.fast_match = fast_match + + self.prototype_dataset = self._prepare_prototype( + prototype, cache=prototype_cache, batch_size=prepare_batch_size) + + def _prepare_prototype(self, prototype, cache=None, batch_size=8): + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + def build_dataloader(dataset): + return DataLoader( + [ + self.text_pipeline({ + 'sample_idx': i, + 'text': text + }) for i, text in enumerate(dataset) + ], + batch_size=batch_size, + collate_fn=default_collate, + sampler=DefaultSampler(dataset, shuffle=False), + persistent_workers=False, + ) + + if isinstance(prototype, str): + # A file path of a list of string + dataset = mmengine.list_from_file(prototype) + elif mmengine.utils.is_seq_of(prototype, str): + dataset = prototype + else: + raise TypeError(f'Unsupported prototype type {type(prototype)}.') + + dataloader = build_dataloader(dataset) + + if cache is not None and Path(cache).exists(): + self.prototype = torch.load(cache) + else: + prototype = [] + for data_batch in track(dataloader, 'Prepare prototype...'): + with torch.no_grad(): + data_batch = self.model.data_preprocessor( + data_batch, False) + feats = self.model._run_forward(data_batch, mode='tensor') + prototype.append(feats) + prototype = { + k: torch.cat([d[k] for d in prototype]) + for k in prototype[0] + } + self.prototype = prototype + + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + if cache is None: + logger.info('The prototype has been prepared, you can use ' + '`save_prototype` to dump it into a pickle ' + 'file for the future usage.') + elif not Path(cache).exists(): + self.save_prototype(cache) + logger.info(f'The prototype has been saved at {cache}.') + + return dataset + + def save_prototype(self, path): + torch.save(self.prototype, path) + + def __call__(self, + inputs: ImageType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (str | array | list): The image path or array, or a list of + images. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. 
+ """ + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + @torch.no_grad() + def forward(self, data: dict, **kwargs): + """Feed the inputs to the model.""" + data = self.model.data_preprocessor(data, False) + feats = self.prototype.copy() + feats.update(self.model.extract_feat(images=data['images'])) + return self.model.predict_all( + feats, data['data_samples'], cal_t2i=False)[0] + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + test_transfroms = [TRANSFORMS.build(t) for t in test_pipeline_cfg] + img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)} + text_info = {'text': 'example'} + img_pipeline = Compose(filter_transforms(test_transfroms, img_info)) + text_pipeline = Compose(filter_transforms(test_transfroms, text_info)) + return img_pipeline, text_pipeline + + def preprocess(self, inputs: List[ImageType], batch_size: int = 1): + + def load_image(input_): + img = imread(input_) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return dict( + img=img, + img_shape=img.shape[:2], + ori_shape=img.shape[:2], + ) + + pipeline = Compose([load_image, self.img_pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[ImageType], + preds: List[DataSample], + topk: int = 3, + resize: Optional[int] = 224, + show: bool = False, + wait_time: int = 0, + draw_score=True, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_i2t_retrieval( + image, + data_sample, + self.prototype_dataset, + topk=topk, + resize=resize, + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess( + self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False, + topk=1, + ) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + matches = [] + for match_score, sample_idx in zip(match_scores, indices): + text = self.prototype_dataset[sample_idx.item()] + matches.append({ + 'match_score': match_score, + 'sample_idx': sample_idx, + 'text': text + }) + results.append(matches) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. 
+ """ + return list_models(pattern=pattern, task='Image-To-Text Retrieval') diff --git a/mmpretrain/apis/nlvr.py b/mmpretrain/apis/nlvr.py new file mode 100644 index 0000000..9977c3b --- /dev/null +++ b/mmpretrain/apis/nlvr.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer +from .model import list_models + +InputType = Tuple[Union[str, np.ndarray], Union[str, np.ndarray], str] +InputsType = Union[List[InputType], InputType] + + +class NLVRInferencer(BaseInferencer): + """The inferencer for Natural Language for Visual Reasoning. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``NLVRInferencer.list_models()`` and you can also + query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + """ + + visualize_kwargs: set = { + 'resize', 'draw_score', 'show', 'show_dir', 'wait_time' + } + + def __call__(self, + inputs: InputsType, + return_datasamples: bool = False, + batch_size: int = 1, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (tuple, List[tuple]): The input data tuples, every tuple + should include three items (left image, right image, text). + The image can be a path or numpy array. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. 
+ """ + assert isinstance(inputs, (tuple, list)) + if isinstance(inputs, tuple): + inputs = [inputs] + for input_ in inputs: + assert isinstance(input_, tuple) + assert len(input_) == 3 + + return super().__call__( + inputs, + return_datasamples=return_datasamples, + batch_size=batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + assert test_pipeline_cfg[0]['type'] == 'ApplyToList' + + list_pipeline = deepcopy(test_pipeline_cfg[0]) + if list_pipeline.scatter_key == 'img_path': + # Remove `LoadImageFromFile` + list_pipeline.transforms.pop(0) + list_pipeline.scatter_key = 'img' + + test_pipeline = Compose( + [TRANSFORMS.build(list_pipeline)] + + [TRANSFORMS.build(t) for t in test_pipeline_cfg[1:]]) + return test_pipeline + + def preprocess(self, inputs: InputsType, batch_size: int = 1): + + def load_image(input_): + img1 = imread(input_[0]) + img2 = imread(input_[1]) + text = input_[2] + if img1 is None: + raise ValueError(f'Failed to read image {input_[0]}.') + if img2 is None: + raise ValueError(f'Failed to read image {input_[1]}.') + return dict( + img=[img1, img2], + img_shape=[img1.shape[:2], img2.shape[:2]], + ori_shape=[img1.shape[:2], img2.shape[:2]], + text=text, + ) + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + pred_scores = data_sample.pred_score + pred_score = float(torch.max(pred_scores).item()) + pred_label = torch.argmax(pred_scores).item() + result = { + 'pred_scores': pred_scores.detach().cpu().numpy(), + 'pred_label': pred_label, + 'pred_score': pred_score, + } + results.append(result) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='NLVR') diff --git a/mmpretrain/apis/utils.py b/mmpretrain/apis/utils.py new file mode 100644 index 0000000..83e763254 --- /dev/null +++ b/mmpretrain/apis/utils.py @@ -0,0 +1,270 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from collections import defaultdict +from contextlib import contextmanager +from itertools import chain +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.utils import require + + +@require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/') +@require('accelerate') +def dispatch_model( + model, + device_map: Union[str, dict], + max_memory: Optional[dict] = None, + no_split_module_classes: Optional[List[str]] = None, + offload_folder: str = None, + offload_buffers: bool = False, + preload_module_classes: Optional[List[str]] = None, +): + """Split and dispatch a model across devices. + + The function depends on the `accelerate` package. Refers to + https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling + + Args: + model (torch.nn.Module): The model to dispatch. + device_map (str | dict | None): A map that specifies where each + submodule should go. 
It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every + submodule of it will be sent to the same device. You can use + `device_map="auto"` to automatically generate the device map. + Defaults to None. + max_memory (dict | None): A dictionary device identifier to maximum + memory. Will default to the maximum memory available for each GPU + and the available CPU RAM if unset. Defaults to None. + no_split_module_classes (List[str] | None): A list of layer class names + that should never be split across device (for instance any layer + that has a residual connection). If None, try to get the settings + from the model class. Defaults to None. + offload_folder (str | None): If the `device_map` contains any value + `"disk"`, the folder where we will offload weights. + offload_buffers (bool): In the layers that are offloaded on the CPU + or the hard drive, whether or not to offload the buffers as + well as the parameters. Defaults to False. + preload_module_classes (List[str] | None): A list of classes whose + instances should load all their weights (even in the submodules) at + the beginning of the forward. This should only be used for classes + that have submodules which are registered but not called directly + during the forward, for instance if a `dense` linear layer is + registered, but at forward, `dense.weight` and `dense.bias` are + used in some operations instead of calling `dense` directly. + Defaults to None. + """ + from accelerate import dispatch_model, infer_auto_device_map + + # Check valid device_map string. + valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential'] + if isinstance(device_map, str) and device_map not in valid_map_option: + raise ValueError('If passing a string for `device_map`, please choose ' + f'from {valid_map_option}.') + + # Generate device map automatically + if isinstance(device_map, str): + if no_split_module_classes is None: + no_split_module_classes = getattr(model, '_no_split_modules', None) + if no_split_module_classes is None: + raise ValueError(f'{model.__class__.__name__} does not support ' + f"`device_map='{device_map}'` yet.") + + if device_map != 'sequential': + from accelerate.utils import get_balanced_memory + max_memory = get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=None, + low_zero=(device_map == 'balanced_low_0'), + ) + max_memory[0] *= 0.9 + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=None, + ) + + if 'disk' in device_map.values(): + if offload_folder is None: + raise ValueError( + 'The current `device_map` had weights offloaded to the disk. ' + 'Please provide an `offload_folder` for them.') + os.makedirs(offload_folder, exist_ok=True) + + main_device = next( + (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu') + + model = dispatch_model( + model, + device_map=device_map, + main_device=main_device, + offload_dir=offload_folder, + offload_buffers=offload_buffers, + preload_module_classes=preload_module_classes, + ) + if hasattr(model, 'data_preprocessor'): + model.data_preprocessor._device = torch.device(main_device) + return model + + +@contextmanager +def init_empty_weights(include_buffers: bool = False): + """A context manager under which models are initialized with all parameters + on the meta device. + + With this context manager, we can create an empty model. 
Useful when just + initializing the model would blow the available RAM. + + Besides move the parameters to meta device, this method will also avoid + load checkpoint from `mmengine.runner.load_checkpoint` and + `transformers.PreTrainedModel.from_pretrained`. + + Modified from https://github.com/huggingface/accelerate + + Args: + include_buffers (bool): Whether put all buffers on the meta device + during initialization. + """ + device = torch.device('meta') + + # move parameter and buffer to meta device + old_register_parameter = nn.Module.register_parameter + if include_buffers: + old_register_buffer = nn.Module.register_buffer + # See https://github.com/huggingface/accelerate/pull/699 + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ['empty', 'zeros', 'ones', 'full'] + } + + def register_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs) + + def register_buffer(module, name, buffer, *args, **kwargs): + old_register_buffer(module, name, buffer, *args, **kwargs) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + def patch_tensor_constructor(fn): + + def wrapper(*args, **kwargs): + kwargs['device'] = device + return fn(*args, **kwargs) + + return wrapper + + # Patch load_checkpoint + import mmengine.runner.checkpoint as mmengine_load + old_load_checkpoint = mmengine_load.load_checkpoint + + def patch_load_checkpoint(*args, **kwargs): + return {} + + # Patch transformers from pretrained + try: + from transformers import PreTrainedModel + from transformers.models.auto.auto_factory import (AutoConfig, + _BaseAutoModelClass) + with_transformers = True + except ImportError: + with_transformers = False + + @classmethod + def patch_auto_model(cls, pretrained_model_name_or_path, *model_args, + **kwargs): + cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, + *model_args, **kwargs) + return cls.from_config(cfg) + + @classmethod + def patch_pretrained_model(cls, pretrained_model_name_or_path, *model_args, + **kwargs): + cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path, + *model_args, **kwargs) + return cls(cfg) + + if with_transformers: + old_pretrained_model = PreTrainedModel.from_pretrained + old_auto_model = _BaseAutoModelClass.from_pretrained + + try: + nn.Module.register_parameter = register_parameter + mmengine_load.load_checkpoint = patch_load_checkpoint + if with_transformers: + PreTrainedModel.from_pretrained = patch_pretrained_model + _BaseAutoModelClass.from_pretrained = patch_auto_model + if include_buffers: + nn.Module.register_buffer = register_buffer + for func in tensor_constructors_to_patch.keys(): + tensor_constructor = patch_tensor_constructor( + getattr(torch, func)) + setattr(torch, func, tensor_constructor) + yield + finally: + nn.Module.register_parameter = old_register_parameter + mmengine_load.load_checkpoint = old_load_checkpoint + if with_transformers: + PreTrainedModel.from_pretrained = old_pretrained_model + _BaseAutoModelClass.from_pretrained = old_auto_model + if include_buffers: + nn.Module.register_buffer = old_register_buffer + for func, ori in tensor_constructors_to_patch.items(): + setattr(torch, func, ori) + + +def compute_module_sizes( + model: nn.Module, + dtype: Union[str, torch.dtype, None] = 
None,
+        special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None):
+    """Compute the size of each submodule of a given model."""
+
+    def get_dtype(dtype):
+        if isinstance(dtype, str):
+            dtype = getattr(torch, dtype)
+        if dtype is not None:
+            assert isinstance(dtype, torch.dtype)
+        return dtype
+
+    def dtype_bytes(dtype: torch.dtype):
+        if dtype is torch.bool:
+            return 1
+        if dtype.is_floating_point:
+            return torch.finfo(dtype).bits / 8
+        else:
+            return torch.iinfo(dtype).bits / 8
+
+    if dtype is not None:
+        dtype = get_dtype(dtype)
+        dtype_size = dtype_bytes(dtype)
+
+    if special_dtypes is not None:
+        special_dtypes = {
+            key: dtype_bytes(dtype)
+            for key, dtype in special_dtypes.items()
+        }
+
+    module_sizes = defaultdict(int)
+    for name, tensor in chain(
+            model.named_parameters(recurse=True),
+            model.named_buffers(recurse=True)):
+        if special_dtypes is not None and name in special_dtypes:
+            size = tensor.numel() * special_dtypes[name]
+        elif dtype is None:
+            size = tensor.numel() * tensor.element_size()
+        else:
+            size = tensor.numel() * min(dtype_size, tensor.element_size())
+        name_parts = name.split('.')
+        for idx in range(len(name_parts) + 1):
+            module_sizes['.'.join(name_parts[:idx])] += size
+
+    return module_sizes
diff --git a/mmpretrain/apis/visual_grounding.py b/mmpretrain/apis/visual_grounding.py
new file mode 100644
index 0000000..0153d56
--- /dev/null
+++ b/mmpretrain/apis/visual_grounding.py
@@ -0,0 +1,182 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer
+from .model import list_models
+
+
+class VisualGroundingInferencer(BaseInferencer):
+    """The inferencer for visual grounding.
+
+    Args:
+        model (BaseModel | str | Config): A model name or a path to the config
+            file, or a :obj:`BaseModel` object. The model name can be found
+            by ``VisualGroundingInferencer.list_models()`` and you can also
+            query it in :doc:`/modelzoo_statistics`.
+        pretrained (str, optional): Path to the checkpoint. If None, it will
+            try to find a pre-defined weight from the model you specified
+            (only work if the ``model`` is a model name). Defaults to None.
+        device (str, optional): Device to run inference. If None, the available
+            device will be automatically used. Defaults to None.
+        **kwargs: Other keyword arguments to initialize the model (only work if
+            the ``model`` is a model name).
+
+    Example:
+        >>> from mmpretrain import VisualGroundingInferencer
+        >>> inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')
+        >>> inferencer('demo/cat-dog.png', 'dog')[0]
+        {'pred_bboxes': tensor([[ 36.6000, 29.6000, 355.8000, 395.2000]])}
+    """ # noqa: E501
+
+    visualize_kwargs: set = {
+        'resize', 'show', 'show_dir', 'wait_time', 'line_width', 'bbox_color'
+    }
+
+    def __call__(self,
+                 images: Union[str, np.ndarray, list],
+                 texts: Union[str, list],
+                 return_datasamples: bool = False,
+                 batch_size: int = 1,
+                 **kwargs) -> dict:
+        """Call the inferencer.
+
+        Args:
+            images (str | array | list): The image path or array, or a list of
+                images.
+            texts (str | list): The text to do visual grounding.
+            return_datasamples (bool): Whether to return results as
+                :obj:`DataSample`. Defaults to False.
+            batch_size (int): Batch size. Defaults to 1.
+ resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + line_width (int): The line width of the bbox. Defaults to 3. + bbox_color (str | tuple): The color of the bbox. + Defaults to 'green'. + + Returns: + list: The inference results. + """ + if not isinstance(images, (list, tuple)): + assert isinstance(texts, str) + inputs = [{'img': images, 'text': texts}] + else: + inputs = [] + for i in range(len(images)): + input_ = {'img': images[i], 'text': texts[i]} + inputs.append(input_) + + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[dict], batch_size: int = 1): + + def load_image(input_: dict): + img = imread(input_['img']) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return {**input_, 'img': img} + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[dict], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + line_width: int = 3, + bbox_color: Union[str, tuple] = 'green', + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_['img']) + if isinstance(input_['img'], str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_['img']).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_visual_grounding( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + line_width=line_width, + bbox_color=bbox_color, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({'pred_bboxes': data_sample.get('pred_bboxes')}) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. 
+ + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Visual Grounding') diff --git a/mmpretrain/apis/visual_question_answering.py b/mmpretrain/apis/visual_question_answering.py new file mode 100644 index 0000000..616e1ed --- /dev/null +++ b/mmpretrain/apis/visual_question_answering.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +from mmcv.image import imread +from mmengine.config import Config +from mmengine.dataset import Compose, default_collate + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample +from .base import BaseInferencer +from .model import list_models + + +class VisualQuestionAnsweringInferencer(BaseInferencer): + """The inferencer for visual question answering. + + Args: + model (BaseModel | str | Config): A model name or a path to the config + file, or a :obj:`BaseModel` object. The model name can be found + by ``VisualQuestionAnsweringInferencer.list_models()`` and you can + also query it in :doc:`/modelzoo_statistics`. + pretrained (str, optional): Path to the checkpoint. If None, it will + try to find a pre-defined weight from the model you specified + (only work if the ``model`` is a model name). Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + **kwargs: Other keyword arguments to initialize the model (only work if + the ``model`` is a model name). + + Example: + >>> from mmpretrain import VisualQuestionAnsweringInferencer + >>> inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa') + >>> inferencer('demo/cat-dog.png', "What's the animal next to the dog?")[0] + {'question': "What's the animal next to the dog?", 'pred_answer': 'cat'} + """ # noqa: E501 + + visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'} + + def __call__(self, + images: Union[str, np.ndarray, list], + questions: Union[str, list], + return_datasamples: bool = False, + batch_size: int = 1, + objects: Optional[List[str]] = None, + **kwargs) -> dict: + """Call the inferencer. + + Args: + images (str | array | list): The image path or array, or a list of + images. + questions (str | list): The question to the correspondding image. + return_datasamples (bool): Whether to return results as + :obj:`DataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + objects (List[List[str]], optional): Some algorithms like OFA + fine-tuned VQA models requires extra object description list + for every image. Defaults to None. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + show (bool): Whether to display the visualization result in a + window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + show_dir (str, optional): If not None, save the visualization + results in the specified directory. Defaults to None. + + Returns: + list: The inference results. 
+ """ + if not isinstance(images, (list, tuple)): + assert isinstance(questions, str) + inputs = [{'img': images, 'question': questions}] + if objects is not None: + assert isinstance(objects[0], str) + inputs[0]['objects'] = objects + else: + inputs = [] + for i in range(len(images)): + input_ = {'img': images[i], 'question': questions[i]} + if objects is not None: + input_['objects'] = objects[i] + inputs.append(input_) + + return super().__call__(inputs, return_datasamples, batch_size, + **kwargs) + + def _init_pipeline(self, cfg: Config) -> Callable: + test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline + from mmpretrain.datasets import remove_transform + + # Image loading is finished in `self.preprocess`. + test_pipeline_cfg = remove_transform(test_pipeline_cfg, + 'LoadImageFromFile') + test_pipeline = Compose( + [TRANSFORMS.build(t) for t in test_pipeline_cfg]) + return test_pipeline + + def preprocess(self, inputs: List[dict], batch_size: int = 1): + + def load_image(input_: dict): + img = imread(input_['img']) + if img is None: + raise ValueError(f'Failed to read image {input_}.') + return {**input_, 'img': img} + + pipeline = Compose([load_image, self.pipeline]) + + chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size) + yield from map(default_collate, chunked_data) + + def visualize(self, + ori_inputs: List[dict], + preds: List[DataSample], + show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, + show_dir=None): + if not show and show_dir is None: + return None + + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_['img']) + if isinstance(input_['img'], str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_['img']).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_vqa( + image, + data_sample, + resize=resize, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization + + def postprocess(self, + preds: List[DataSample], + visualization: List[np.ndarray], + return_datasamples=False) -> dict: + if return_datasamples: + return preds + + results = [] + for data_sample in preds: + results.append({ + 'question': data_sample.get('question'), + 'pred_answer': data_sample.get('pred_answer'), + }) + + return results + + @staticmethod + def list_models(pattern: Optional[str] = None): + """List all available model names. + + Args: + pattern (str | None): A wildcard pattern to match model names. + + Returns: + List[str]: a list of model names. + """ + return list_models(pattern=pattern, task='Visual Question Answering') diff --git a/mmpretrain/configs/_base_/datasets/cifar10_bs16.py b/mmpretrain/configs/_base_/datasets/cifar10_bs16.py new file mode 100644 index 0000000..3737dbe --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/cifar10_bs16.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import CIFAR10, PackInputs, RandomCrop, RandomFlip +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = CIFAR10 +data_preprocessor = dict( + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type=RandomCrop, crop_size=32, padding=4), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/cifar10', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/cifar10/', + split='test', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/cub_bs8_384.py b/mmpretrain/configs/_base_/datasets/cub_bs8_384.py new file mode 100644 index 0000000..b193bf8 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/cub_bs8_384.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile, + PackInputs, RandomCrop, RandomFlip, Resize) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = CUB +data_preprocessor = dict( + num_classes=200, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=510), + dict(type=RandomCrop, crop_size=384), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=510), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=8, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/CUB_200_2011', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=8, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/CUB_200_2011', + split='test', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet21k_bs128.py b/mmpretrain/configs/_base_/datasets/imagenet21k_bs128.py new file mode 100644 index 0000000..11c4c0a --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet21k_bs128.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (ImageNet21k, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop) + +# dataset settings +dataset_type = ImageNet21k +data_preprocessor = dict( + num_classes=21842, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet21k', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py b/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py new file mode 100644 index 0000000..cf0aa62 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (AutoAugment, CenterCrop, ImageNet, + LoadImageFromFile, PackInputs, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in bgr_mean])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs256_beitv2.py b/mmpretrain/configs/_base_/datasets/imagenet_bs256_beitv2.py new file mode 100644 index 0000000..f89eb17 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs256_beitv2.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler, default_collate + +from mmpretrain.datasets import (BEiTMaskGenerator, ColorJitter, ImageNet, + LoadImageFromFile, PackInputs, RandomFlip, + RandomResizedCropAndInterpolationWithTwoPic) +from mmpretrain.models import TwoNormDataPreprocessor + +dataset_type = ImageNet +data_root = 'data/imagenet/' + +data_preprocessor = dict( + type=TwoNormDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + second_mean=[127.5, 127.5, 127.5], + second_std=[127.5, 127.5, 127.5], + to_rgb=True) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, + hue=0.), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandomResizedCropAndInterpolationWithTwoPic, + size=224, + second_size=224, + interpolation='bicubic', + second_interpolation='bicubic', + scale=(0.2, 1.0)), + dict( + type=BEiTMaskGenerator, + input_size=(14, 14), + num_masking_patches=75, + max_num_patches=75, + min_num_patches=16), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=256, + num_workers=8, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate), + dataset=dict( + type=dataset_type, + data_root=data_root, + split='train', + pipeline=train_pipeline)) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32.py new file mode 100644 index 0000000..7d07400 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py new file mode 100644 index 0000000..f911bc2 --- /dev/null +++ 
b/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py new file mode 100644 index 0000000..29b698f --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmcv.transforms import (LoadImageFromFile, RandomApply, RandomFlip, + RandomGrayscale) +from mmengine.dataset import DefaultSampler, default_collate + +from mmpretrain.datasets import (ColorJitter, GaussianBlur, ImageNet, + MultiView, PackInputs, RandomResizedCrop) +from mmpretrain.models import SelfSupDataPreprocessor + +# dataset settings +dataset_type = ImageNet +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=SelfSupDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True) + +view_pipeline = [ + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomApply, + transforms=[ + dict( + type=ColorJitter, + brightness=0.8, + contrast=0.8, + saturation=0.8, + hue=0.2) + ], + prob=0.8), + dict( + type=RandomGrayscale, + prob=0.2, + keep_channels=True, + channel_weights=(0.114, 0.587, 0.2989)), + dict( + type=GaussianBlur, + magnitude_range=(0.1, 2.0), + magnitude_std='inf', + prob=0.5), +] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=MultiView, num_views=2, transforms=[view_pipeline]), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=32, + num_workers=4, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate), + dataset=dict( + type=ImageNet, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py b/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py new file mode 100644 index 0000000..017f5b7 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmcv.transforms import LoadImageFromFile, RandomFlip +from mmengine.dataset.sampler import DefaultSampler + +from mmpretrain.datasets import ImageNet, PackInputs, RandomResizedCrop +from mmpretrain.models import SelfSupDataPreprocessor + +# dataset settings +dataset_type = ImageNet +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=SelfSupDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + crop_ratio_range=(0.2, 1.0), + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=512, + num_workers=8, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type=dataset_type, + data_root=data_root, + split='train', + pipeline=train_pipeline)) diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize.py new file mode 100644 index 0000000..a2d8aea --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py new file mode 100644 index 0000000..a5f0526 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.datasets.transforms import AutoAugment +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_224.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_224.py new file mode 100644 index 0000000..5a38943 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_224.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandAugment, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py new file mode 100644 index 0000000..9690ff8 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandAugment, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=256, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=292, # ( 256 / 224 * 256 ) + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=256), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py new file mode 100644 index 0000000..85aeb1e --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
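Note (annotation, not part of the patch): the `scale=292` used by ResizeEdge above keeps the usual 256/224 resize-to-crop ratio at the larger 256px crop, as the inline comment hints:

crop_size = 256
resize_edge = crop_size * 256 / 224    # = 292.57..., the config rounds to 292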
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (ImageNet, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, Resize) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=384, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=384, backend='pillow', interpolation='bicubic'), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/default_runtime.py b/mmpretrain/configs/_base_/default_runtime.py new file mode 100644 index 0000000..b5c748e --- /dev/null +++ b/mmpretrain/configs/_base_/default_runtime.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.visualization import LocalVisBackend + +from mmpretrain.engine.hooks import VisualizationHook +from mmpretrain.visualization import UniversalVisualizer + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + + # print log every 100 iterations. + logger=dict(type=LoggerHook, interval=100), + + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + + # save checkpoint per epoch. + checkpoint=dict(type=CheckpointHook, interval=1), + + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), + + # validation results visualization, set True to enable it. + visualization=dict(type=VisualizationHook, enable=False), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +vis_backends = [dict(type=LocalVisBackend)] +visualizer = dict(type=UniversalVisualizer, vis_backends=vis_backends) + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Use a random seed by default and disable `deterministic` mode. +randomness = dict(seed=None, deterministic=False) + +# There is no need to specify default_scope with the new config format.
Therefore set it to +# None to avoid BC-breaking. +default_scope = None diff --git a/mmpretrain/configs/_base_/models/convnext_base.py b/mmpretrain/configs/_base_/models/convnext_base.py new file mode 100644 index 0000000..6315b2f --- /dev/null +++ b/mmpretrain/configs/_base_/models/convnext_base.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.model import TruncNormalInit + +from mmpretrain.models import (ConvNeXt, CutMix, ImageClassifier, + LabelSmoothLoss, LinearClsHead, Mixup) + +# Model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=ConvNeXt, arch='base', drop_path_rate=0.5), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + init_cfg=None, + ), + init_cfg=dict( + type=TruncNormalInit, layer=['Conv2d', 'Linear'], std=.02, bias=0.), + train_cfg=dict(augments=[ + dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0), + ]), +) diff --git a/mmpretrain/configs/_base_/models/mae_hivit_base_p16.py b/mmpretrain/configs/_base_/models/mae_hivit_base_p16.py new file mode 100644 index 0000000..975e16b --- /dev/null +++ b/mmpretrain/configs/_base_/models/mae_hivit_base_p16.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (MAE, MAEHiViT, MAEPretrainDecoder, + MAEPretrainHead, PixelReconstructionLoss) + +# model settings +model = dict( + type=MAE, + backbone=dict(type=MAEHiViT, patch_size=16, arch='base', mask_ratio=0.75), + neck=dict( + type=MAEPretrainDecoder, + patch_size=16, + in_chans=3, + embed_dim=512, + decoder_embed_dim=512, + decoder_depth=6, + decoder_num_heads=16, + mlp_ratio=4., + ), + head=dict( + type=MAEPretrainHead, + norm_pix=True, + patch_size=16, + loss=dict(type=PixelReconstructionLoss, criterion='L2')), + init_cfg=[ + dict(type='Xavier', layer='Linear', distribution='uniform'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) diff --git a/mmpretrain/configs/_base_/models/mae_vit_base_p16.py b/mmpretrain/configs/_base_/models/mae_vit_base_p16.py new file mode 100644 index 0000000..9347d1e --- /dev/null +++ b/mmpretrain/configs/_base_/models/mae_vit_base_p16.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (MAE, MAEPretrainDecoder, MAEPretrainHead, + MAEViT, PixelReconstructionLoss) + +# model settings +model = dict( + type=MAE, + backbone=dict(type=MAEViT, arch='b', patch_size=16, mask_ratio=0.75), + neck=dict( + type=MAEPretrainDecoder, + patch_size=16, + in_chans=3, + embed_dim=768, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4., + ), + head=dict( + type=MAEPretrainHead, + norm_pix=True, + patch_size=16, + loss=dict(type=PixelReconstructionLoss, criterion='L2')), + init_cfg=[ + dict(type='Xavier', layer='Linear', distribution='uniform'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) diff --git a/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py b/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py new file mode 100644 index 0000000..17dbb9f --- /dev/null +++ b/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
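Note (annotation, not part of the patch): for the MAE model bases above, a quick sanity check of what `mask_ratio=0.75` means, assuming the usual 224x224 pretraining resolution (the image size is set by the dataset config, not by these model files):

img_size, patch_size, mask_ratio = 224, 16, 0.75
num_patches = (img_size // patch_size) ** 2    # 14 * 14 = 196 patches
num_masked = int(num_patches * mask_ratio)     # 147 patches are dropped
num_visible = num_patches - num_masked         # 49 patches reach the encoder
print(num_patches, num_masked, num_visible)    # 196 147 49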
+# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, MobileNetV2) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=MobileNetV2, widen_factor=1.0), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1280, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5), + )) diff --git a/mmpretrain/configs/_base_/models/mobilenet_v3_small.py b/mmpretrain/configs/_base_/models/mobilenet_v3_small.py new file mode 100644 index 0000000..83edab5 --- /dev/null +++ b/mmpretrain/configs/_base_/models/mobilenet_v3_small.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.model.weight_init import NormalInit +from torch.nn.modules.activation import Hardswish + +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, MobileNetV3, + StackedLinearClsHead) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=MobileNetV3, arch='small'), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=StackedLinearClsHead, + num_classes=1000, + in_channels=576, + mid_channels=[1024], + dropout_rate=0.2, + act_cfg=dict(type=Hardswish), + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + init_cfg=dict( + type=NormalInit, layer='Linear', mean=0., std=0.01, bias=0.), + topk=(1, 5))) diff --git a/mmpretrain/configs/_base_/models/resnet18.py b/mmpretrain/configs/_base_/models/resnet18.py new file mode 100644 index 0000000..30b8f65 --- /dev/null +++ b/mmpretrain/configs/_base_/models/resnet18.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, ResNet) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=ResNet, + depth=18, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=512, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5), + )) diff --git a/mmpretrain/configs/_base_/models/swin_transformer_base.py b/mmpretrain/configs/_base_/models/swin_transformer_base.py new file mode 100644 index 0000000..c73c254 --- /dev/null +++ b/mmpretrain/configs/_base_/models/swin_transformer_base.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
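Note (annotation, not part of the patch): in the classifier bases above, the head's `in_channels` must match the backbone's final feature width (1280 for MobileNetV2 x1.0, 576 for MobileNetV3-small, 512 for ResNet-18). Adapting one of them to another label space follows the same override pattern the in21k ConvNeXt configs use later in this diff; a sketch with a hypothetical 10-class target:

from mmengine.config import read_base

with read_base():
    from .._base_.models.resnet18 import *  # noqa: F401,F403

# same style as model.update(head=dict(num_classes=21841)) in the in21k configs
model.update(head=dict(num_classes=10))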
+from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, SwinTransformer) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=SwinTransformer, + arch='base', + img_size=384, + stage_cfgs=dict(block_cfgs=dict(window_size=12))), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py b/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py new file mode 100644 index 0000000..c7566b5 --- /dev/null +++ b/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (GlobalAveragePooling, ImageClassifier, + LabelSmoothLoss, LinearClsHead, + SwinTransformerV2) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=SwinTransformerV2, arch='base', img_size=384, drop_path_rate=0.2), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + cal_acc=False)) diff --git a/mmpretrain/configs/_base_/models/vit_base_p16.py b/mmpretrain/configs/_base_/models/vit_base_p16.py new file mode 100644 index 0000000..326c50a --- /dev/null +++ b/mmpretrain/configs/_base_/models/vit_base_p16.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.model.weight_init import KaimingInit + +from mmpretrain.models import (ImageClassifier, LabelSmoothLoss, + VisionTransformer, VisionTransformerClsHead) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=VisionTransformer, + arch='b', + img_size=224, + patch_size=16, + drop_rate=0.1, + init_cfg=[ + dict( + type=KaimingInit, + layer='Conv2d', + mode='fan_in', + nonlinearity='linear') + ]), + neck=None, + head=dict( + type=VisionTransformerClsHead, + num_classes=1000, + in_channels=768, + loss=dict( + type=LabelSmoothLoss, label_smooth_val=0.1, mode='classy_vision'), + )) diff --git a/mmpretrain/configs/_base_/schedules/cifar10_bs128.py b/mmpretrain/configs/_base_/schedules/cifar10_bs128.py new file mode 100644 index 0000000..8ab749e --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/cifar10_bs128.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import MultiStepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) +# learning policy +param_scheduler = dict( + type=MultiStepLR, by_epoch=True, milestones=[100, 150], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. 
+auto_scale_lr = dict(base_batch_size=128) diff --git a/mmpretrain/configs/_base_/schedules/cub_bs64.py b/mmpretrain/configs/_base_/schedules/cub_bs64.py new file mode 100644 index 0000000..2ca40bf --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/cub_bs64.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict( + type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True)) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type=LinearLR, + start_factor=0.01, + by_epoch=True, + begin=0, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict( + type=CosineAnnealingLR, + T_max=95, + by_epoch=True, + begin=5, + end=100, + ) +] + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=64) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py b/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py new file mode 100644 index 0000000..60ccaa0 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +# for batch in each gpu is 128, 8 gpu +# lr = 5e-4 * 128 * 8 / 512 = 0.001 +optim_wrapper = dict( + optimizer=dict( + type=AdamW, + lr=5e-4 * 1024 / 512, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + flat_decay_mult=0.0, + custom_keys={ + '.absolute_pos_embed': dict(decay_mult=0.0), + '.relative_position_bias_table': dict(decay_mult=0.0) + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type=LinearLR, + start_factor=1e-3, + by_epoch=True, + end=20, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict(type=CosineAnnealingLR, eta_min=1e-5, by_epoch=True, begin=20) +] + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs256.py b/mmpretrain/configs/_base_/schedules/imagenet_bs256.py new file mode 100644 index 0000000..95afa2a --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs256.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
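Note (annotation, not part of the patch): the AdamW/Swin schedule above bakes the linear LR scaling rule into its base learning rate, and `auto_scale_lr` re-applies the same rule when it is enabled for a run (it is opt-in; the numbers below only illustrate the arithmetic):

base_lr = 5e-4 * 1024 / 512            # = 1e-3, the lr written in the schedule
base_batch_size = 1024                 # auto_scale_lr reference batch size
actual_batch_size = 8 * 64             # e.g. 8 GPUs x 64 samples per GPU
scaled_lr = base_lr * actual_batch_size / base_batch_size    # = 5e-4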
+from mmengine.optim import MultiStepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + type=MultiStepLR, by_epoch=True, milestones=[30, 60, 90], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=256) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py b/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py new file mode 100644 index 0000000..9d245eb --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import StepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.045, momentum=0.9, weight_decay=0.00004)) + +# learning policy +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=1, gamma=0.98) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=256) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs4096_adamw.py b/mmpretrain/configs/_base_/schedules/imagenet_bs4096_adamw.py new file mode 100644 index 0000000..4561f23 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs4096_adamw.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=AdamW, lr=0.003, weight_decay=0.3), + # specific to vit pretrain + paramwise_cfg=dict(custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + }), +) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=30, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict( + type=CosineAnnealingLR, + T_max=270, + by_epoch=True, + begin=30, + end=300, + ) +] + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py b/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py new file mode 100644 index 0000000..0c7e617 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
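Note (annotation, not part of the patch): the warmup-plus-cosine recipe in the bs4096 AdamW schedule above hands over at epoch 30: LinearLR ramps from 0.003 * 1e-4 up to 0.003, then CosineAnnealingLR decays it over the remaining 270 epochs. Assuming the default eta_min of 0, the midpoint of the cosine phase lands at half the base LR:

import math

base_lr, T_max = 0.003, 270
t = 135                                   # halfway through the cosine phase
lr = base_lr * 0.5 * (1 + math.cos(math.pi * t / T_max))
print(lr)                                 # ~0.0015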
+from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop + +from mmpretrain.engine.optimizers.lars import LARS + +# optimizer wrapper +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=LARS, lr=4.8, weight_decay=1e-6, momentum=0.9)) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict(type=CosineAnnealingLR, T_max=190, by_epoch=True, begin=10, end=200) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=200) diff --git a/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py b/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py new file mode 100644 index 0000000..fe9c329 --- /dev/null +++ b/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.default_runtime import * + +from mmengine.dataset import DefaultSampler, default_collate +from mmengine.hooks import CheckpointHook +from mmengine.model import ConstantInit, PretrainedInit, TruncNormalInit +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.datasets import (BEiTMaskGenerator, ColorJitter, ImageNet, + LoadImageFromFile, PackInputs, RandomFlip, + RandomResizedCropAndInterpolationWithTwoPic) +from mmpretrain.models import (BEiT, BEiTPretrainViT, BEiTV1Head, + CrossEntropyLoss, DALLEEncoder, + TwoNormDataPreprocessor) + +# dataset settings +dataset_type = ImageNet +data_root = 'data/imagenet/' +data_preprocessor = dict( + type=TwoNormDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + second_mean=[-31.875, -31.875, -31.875], + second_std=[318.75, 318.75, 318.75], + to_rgb=True) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, + hue=0.), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandomResizedCropAndInterpolationWithTwoPic, + size=224, + second_size=112, + interpolation='bicubic', + second_interpolation='lanczos', + scale=(0.08, 1.0)), + dict( + type=BEiTMaskGenerator, + input_size=(14, 14), + num_masking_patches=75, + max_num_patches=None, + min_num_patches=16), + dict(type=PackInputs) +] +train_dataloader = dict( + batch_size=256, + num_workers=8, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) + +# model settings +model = dict( + type=BEiT, + backbone=dict( + type=BEiTPretrainViT, + arch='base', + patch_size=16, + drop_path_rate=0.1, + final_norm=True, + out_type='raw', + layer_scale_init_value=0.1, + init_cfg=[ + dict(type=TruncNormalInit, std=0.02, layer='Linear'), + dict(type=TruncNormalInit, std=0.02, layer='Conv2d'), + dict(type=ConstantInit, layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=None, + head=dict( + type=BEiTV1Head, + embed_dims=768, + num_embed=8192, + 
loss=dict(type=CrossEntropyLoss)), + target_generator=dict( + type=DALLEEncoder, + init_cfg=dict( + type=PretrainedInit, + checkpoint= # noqa: E251 + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/dalle_encoder.pth', # noqa: E501 + ))) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05), + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=1e-5, + by_epoch=True, + begin=10, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +default_hooks.update( + # only keeps the latest 3 checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) + +randomness.update(seed=0, diff_rank_seed=True) + +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=2048) diff --git a/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py b/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py new file mode 100644 index 0000000..00a76b7 --- /dev/null +++ b/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
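Note (annotation, not part of the patch): bookkeeping for the BEiT pretraining config above, using only values from the pipeline and head. Each masked patch is classified against the frozen DALL-E encoder's codebook (the `target_generator`), which is what the CrossEntropyLoss in `BEiTV1Head` scores:

num_patches = 14 * 14        # BEiTMaskGenerator input_size=(14, 14)
num_masked = 75              # num_masking_patches
codebook_size = 8192         # num_embed in BEiTV1Head
print(num_masked / num_patches)    # ~0.38 of the patches carry the loss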
+from mmengine.config import read_base + +with read_base(): + from ..._base_.datasets.imagenet_bs64_swin_224 import * + from ..._base_.schedules.imagenet_bs1024_adamw_swin import * + from ..._base_.default_runtime import * + +from mmengine.hooks import CheckpointHook +from mmengine.model import PretrainedInit, TruncNormalInit +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +from mmpretrain.datasets import LoadImageFromFile, PackInputs, RandomFlip +from mmpretrain.engine.optimizers import \ + LearningRateDecayOptimWrapperConstructor +from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss, + LinearClsHead) +from mmpretrain.models.utils.batch_augments import CutMix, Mixup + +data_preprocessor = dict( + num_classes=1000, + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + to_rgb=True, +) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=BEiTViT, + arch='base', + img_size=224, + patch_size=16, + drop_path_rate=0.1, + out_type='avg_featmap', + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + init_cfg=dict(type=PretrainedInit, checkpoint='', prefix='backbone.')), + neck=None, + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=768, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + init_cfg=[dict(type=TruncNormalInit, layer='Linear', std=0.02)]), + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=0.3333333333333333, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs) +] + +train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# optimizer wrapper +optim_wrapper = dict( + optimizer=dict(type=AdamW, lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)), + constructor=LearningRateDecayOptimWrapperConstructor, + paramwise_cfg=dict( + _delete_=True, + layer_decay_rate=0.65, + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + by_epoch=True, + begin=20, + end=100, + eta_min=1e-6, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + # save checkpoint per epoch. 
+ checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=2)) + +train_cfg = dict(by_epoch=True, max_epochs=100) + +randomness = dict(seed=0) diff --git a/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py b/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py new file mode 100644 index 0000000..b4718af --- /dev/null +++ b/mmpretrain/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from ..._base_.datasets.imagenet_bs64_swin_224 import * + from ..._base_.schedules.imagenet_bs1024_adamw_swin import * + from ..._base_.default_runtime import * + +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss, + LinearClsHead) +from mmpretrain.models.utils.batch_augments import CutMix, Mixup + +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +model = dict( + type=ImageClassifier, + backbone=dict( + type=BEiTViT, + arch='base', + img_size=224, + patch_size=16, + out_type='avg_featmap', + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + ), + neck=None, + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=768, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=.02), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py b/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py new file mode 100644 index 0000000..6bec16b --- /dev/null +++ b/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs256_beitv2 import * + from .._base_.default_runtime import * + +from mmengine.model import ConstantInit, PretrainedInit, TruncNormalInit +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.models import (VQKD, BEiT, BEiTPretrainViT, BEiTV2Head, + BEiTV2Neck, CrossEntropyLoss) + +vqkd_encoder = dict( + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='featmap', + with_cls_token=True, + frozen_stages=-1, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + layer_scale_init_value=0., + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None) + +layer_scale_init_value = 0.1 +drop_path_rate = 0.1 # 0. for 300 epochs and 0.1 for 1600 epochs. 
+ +model = dict( + type=BEiT, + backbone=dict( + type=BEiTPretrainViT, + arch='base', + patch_size=16, + out_indices=[-4, -1], + drop_path_rate=drop_path_rate, + final_norm=False, + out_type='raw', + layer_scale_init_value=layer_scale_init_value, + init_cfg=[ + dict(type=TruncNormalInit, std=0.02, layer='Linear'), + dict(type=TruncNormalInit, std=0.02, layer='Conv2d'), + dict(type=ConstantInit, layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=dict( + type=BEiTV2Neck, + num_layers=2, + early_layers=9, + backbone_arch='base', + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + ), + head=dict( + type=BEiTV2Head, + embed_dims=768, + num_embed=8192, + loss=dict(type=CrossEntropyLoss)), + target_generator=dict( + type=VQKD, + encoder_config=vqkd_encoder, + init_cfg=dict( + type=PretrainedInit, + checkpoint= # noqa + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/vqkd_encoder.pth' # noqa + ))) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + # betas: (0.9, 0.98) for 300 epochs and (0.9, 0.999) for 1600 epochs. + optimizer=dict( + type=AdamW, lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05), + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=1e-5, + by_epoch=True, + begin=10, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +default_hooks = dict( + # only keeps the latest 3 checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) + +randomness = dict(seed=0, diff_rank_seed=True) + +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=2048) diff --git a/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py b/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py new file mode 100644 index 0000000..3fe9b50 --- /dev/null +++ b/mmpretrain/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
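Note (annotation, not part of the patch): per the inline comments in the two BEiT v2 pretraining files, the 1600-epoch run above and the 300-epoch run below differ in three knobs; the VQ-KD tokenizer, masking setup, and LR schedule shape are shared:

beitv2_knobs = {
    '1600e': dict(drop_path_rate=0.1, adamw_betas=(0.9, 0.999), max_epochs=1600),
    '300e': dict(drop_path_rate=0.0, adamw_betas=(0.9, 0.98), max_epochs=300),
}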
+from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs256_beitv2 import * + from .._base_.default_runtime import * + +from mmengine.model import ConstantInit, PretrainedInit, TruncNormalInit +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.models import (VQKD, BEiT, BEiTPretrainViT, BEiTV2Head, + BEiTV2Neck, CrossEntropyLoss) + +# model settings +vqkd_encoder = dict( + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='featmap', + with_cls_token=True, + frozen_stages=-1, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + layer_scale_init_value=0., + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None) + +layer_scale_init_value = 0.1 +drop_path_rate = 0. # 0. for 300 epochs and 0.1 for 1600 epochs. +model = dict( + type=BEiT, + backbone=dict( + type=BEiTPretrainViT, + arch='base', + patch_size=16, + out_indices=[-4, -1], + drop_path_rate=drop_path_rate, + final_norm=False, + out_type='raw', + layer_scale_init_value=layer_scale_init_value, + init_cfg=[ + dict(type=TruncNormalInit, std=0.02, layer='Linear'), + dict(type=TruncNormalInit, std=0.02, layer='Conv2d'), + dict(type=ConstantInit, layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=dict( + type=BEiTV2Neck, + num_layers=2, + early_layers=9, + backbone_arch='base', + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + ), + head=dict( + type=BEiTV2Head, + embed_dims=768, + num_embed=8192, + loss=dict(type=CrossEntropyLoss)), + target_generator=dict( + type=VQKD, + encoder_config=vqkd_encoder, + init_cfg=dict( + type=PretrainedInit, + checkpoint= # noqa + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/vqkd_encoder.pth' # noqa + ))) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + # betas: (0.9, 0.98) for 300 epochs and (0.9, 0.999) for 1600 epochs. + optimizer=dict( + type=AdamW, lr=1.5e-3, betas=(0.9, 0.98), weight_decay=0.05), + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=1e-5, + by_epoch=True, + begin=10, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +default_hooks = dict( + # only keeps the latest 3 checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) + +randomness = dict(seed=0, diff_rank_seed=True) + +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. 
+auto_scale_lr = dict(base_batch_size=2048) diff --git a/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py b/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py new file mode 100644 index 0000000..ee32d3a --- /dev/null +++ b/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from ..._base_.datasets.imagenet_bs64_swin_224 import * + from ..._base_.schedules.imagenet_bs1024_adamw_swin import * + from ..._base_.default_runtime import * + +from mmengine.model import PretrainedInit, TruncNormalInit +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +from mmpretrain.engine.optimizers import \ + LearningRateDecayOptimWrapperConstructor +from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss, + LinearClsHead) +from mmpretrain.models.utils.batch_augments import CutMix, Mixup + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=BEiTViT, + arch='base', + img_size=224, + patch_size=16, + # 0.2 for 1600 epochs pretrained models and 0.1 for 300 epochs. + drop_path_rate=0.1, + out_type='avg_featmap', + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + init_cfg=dict(type=PretrainedInit, checkpoint='', prefix='backbone.')), + neck=None, + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=768, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + init_cfg=[dict(type=TruncNormalInit, layer='Linear', std=0.02)]), + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=0.3333333333333333, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs) +] + +train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# optimizer wrapper +optim_wrapper = dict( + optimizer=dict(type=AdamW, lr=5e-4, weight_decay=0.05, betas=(0.9, 0.999)), + constructor=LearningRateDecayOptimWrapperConstructor, + paramwise_cfg=dict( + _delete_=True, + # 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs + layer_decay_rate=0.65, + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate 
scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + by_epoch=True, + begin=20, + end=100, + eta_min=1e-6, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + # save checkpoint per epoch. + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=2)) + +train_cfg = dict(by_epoch=True, max_epochs=100) + +randomness = dict(seed=0) diff --git a/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py b/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py new file mode 100644 index 0000000..ec20ba9 --- /dev/null +++ b/mmpretrain/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from ..._base_.datasets.imagenet_bs64_swin_224 import * + from ..._base_.schedules.imagenet_bs1024_adamw_swin import * + from ..._base_.default_runtime import * + +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import (BEiTViT, ImageClassifier, LabelSmoothLoss, + LinearClsHead) +from mmpretrain.models.utils.batch_augments.cutmix import CutMix +from mmpretrain.models.utils.batch_augments.mixup import Mixup + +model = dict( + type=ImageClassifier, + backbone=dict( + type=BEiTViT, + arch='base', + img_size=224, + patch_size=16, + out_type='avg_featmap', + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + ), + neck=None, + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=768, + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=.02), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/convnext/convnext-base_32xb128_in1k.py b/mmpretrain/configs/convnext/convnext-base_32xb128_in1k.py new file mode 100644 index 0000000..3e8a10f --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-base_32xb128_in1k.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +from mmpretrain.engine import EMAHook + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-base_32xb128_in21k.py b/mmpretrain/configs/convnext/convnext-base_32xb128_in21k.py new file mode 100644 index 0000000..73fb0a0 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-base_32xb128_in21k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
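Note (annotation, not part of the patch): the `auto_scale_lr` reference in the ConvNeXt configs is simply GPUs times per-GPU batch, as the comments spell out:

num_gpus, samples_per_gpu = 32, 128
assert num_gpus * samples_per_gpu == 4096    # matches base_batch_size=4096
# at that reference batch size the configured AdamW lr is the full 4e-3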
+# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting +model.update(head=dict(num_classes=21841)) + +# dataset setting +data_preprocessor.update(num_classes=21841) +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-large_64xb64_in1k-384px.py b/mmpretrain/configs/convnext/convnext-large_64xb64_in1k-384px.py new file mode 100644 index 0000000..2da428a --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-large_64xb64_in1k-384px.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# dataset setting +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-large_64xb64_in1k.py b/mmpretrain/configs/convnext/convnext-large_64xb64_in1k.py new file mode 100644 index 0000000..e11e6a9 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-large_64xb64_in1k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# dataset setting +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=1e-4, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-large_64xb64_in21k.py b/mmpretrain/configs/convnext/convnext-large_64xb64_in21k.py new file mode 100644 index 0000000..d103dfa --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-large_64xb64_in21k.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting +model.update(head=dict(num_classes=21841)) + +# dataset setting +data_preprocessor.update(num_classes=21841) +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-small_32xb128_in1k-384px.py b/mmpretrain/configs/convnext/convnext-small_32xb128_in1k-384px.py new file mode 100644 index 0000000..9b7bce7 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-small_32xb128_in1k-384px.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-small_32xb128_in1k.py b/mmpretrain/configs/convnext/convnext-small_32xb128_in1k.py new file mode 100644 index 0000000..bd43ec1 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-small_32xb128_in1k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=1e-4, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k-384px.py b/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k-384px.py new file mode 100644 index 0000000..9b7bce7 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k-384px.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k.py b/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k.py new file mode 100644 index 0000000..bd43ec1 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-tiny_32xb128_in1k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=1e-4, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k-384px.py b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k-384px.py new file mode 100644 index 0000000..2da428a --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k-384px.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# dataset setting +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k.py b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k.py new file mode 100644 index 0000000..bdb1157 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in1k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# dataset setting +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=None, +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=1e-4, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (64 GPUs) x (64 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in21k.py b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in21k.py new file mode 100644 index 0000000..21f10dc --- /dev/null +++ b/mmpretrain/configs/convnext/convnext-xlarge_64xb64_in21k.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +from mmpretrain.engine import EMAHook + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model setting +model.update(head=dict(num_classes=21841)) + +# dataset setting +data_preprocessor.update(num_classes=21841) +train_dataloader.update(batch_size=64) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py b/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py new file mode 100644 index 0000000..6d90e71 --- /dev/null +++ b/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.convnext_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +from mmpretrain.engine import EMAHook + +# dataset setting +train_dataloader.update(batch_size=128) + +# schedule setting +optim_wrapper.update( + optimizer=dict(lr=4e-3), + clip_grad=dict(max_norm=5.0), +) + +# runtime setting +custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py b/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py new file mode 100644 index 0000000..a254ac8 --- /dev/null +++ b/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
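The trailing `auto_scale_lr = dict(base_batch_size=4096)` in each file feeds MMEngine's optional linear LR scaling: when it is enabled at launch time, the configured LR is rescaled by the ratio of the actual total batch size to this base value. A small illustration of the rule (the helper name is mine):

def scale_lr(base_lr: float, real_batch_size: int,
             base_batch_size: int = 4096) -> float:
    # Linear scaling rule: LR grows/shrinks with the effective batch size.
    return base_lr * real_batch_size / base_batch_size

# The ConvNeXt schedule assumes lr=4e-3 at 32 GPUs x 128 images = 4096;
# running the same config on 8 GPUs x 128 images would rescale it to 1e-3.
print(scale_lr(4e-3, real_batch_size=8 * 128))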
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks import CheckpointHook +from mmengine.optim import CosineAnnealingLR, LinearLR, OptimWrapper +from mmengine.runner import EpochBasedTrainLoop +from torch.optim import AdamW + +from mmpretrain.models import (EVA, CLIPGenerator, CosineSimilarityLoss, + MAEPretrainDecoder, MIMHead) + +# dataset settings +train_dataloader.batch_size = 256 + +# model settings +model.type = EVA +model.init_cfg = None +model.backbone.update(init_cfg=[ + dict(type='Xavier', distribution='uniform', layer='Linear'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) +]) +model.neck.update( + type=MAEPretrainDecoder, + predict_feature_dim=512, + init_cfg=[ + dict(type='Xavier', distribution='uniform', layer='Linear'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) +model.head = dict( + type=MIMHead, + loss=dict(type=CosineSimilarityLoss, shift_factor=2.0, scale_factor=2.0)) +model.target_generator = dict( + type=CLIPGenerator, + tokenizer_path= # noqa + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/clip_vit_base_16.pth.tar' # noqa +) + +# optimizer wrapper +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) +find_unused_parameters = True + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(dict(seed=0, diff_rank_seed=True)) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000..a32cb0c --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
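The EVA config above keeps the MAE-style masked encoder but regresses CLIP ViT-B/16 features (the `CLIPGenerator` target) instead of pixels, scored by `CosineSimilarityLoss(shift_factor=2.0, scale_factor=2.0)`. A hedged sketch of a loss with that shape; with shift = scale = 2 it equals the squared distance between L2-normalized vectors, though the exact masking and reduction inside mmpretrain's MIMHead may differ:

import torch
import torch.nn.functional as F

def cosine_regression_loss(pred: torch.Tensor, target: torch.Tensor,
                           shift: float = 2.0, scale: float = 2.0) -> torch.Tensor:
    # pred/target: (N, L, C), one feature vector per masked patch.
    cos = F.cosine_similarity(pred, target, dim=-1)   # (N, L)
    return (shift - scale * cos).mean()               # 0 when pred == target

pred = torch.randn(2, 196, 512)
target = torch.randn(2, 196, 512)
print(cosine_regression_loss(pred, target).item())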
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_400e_in1k.py b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_400e_in1k.py new file mode 100644 index 0000000..6ffcf6d --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_400e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
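All of the `*_amp_*` MAE configs use `AmpOptimWrapper` with `loss_scale='dynamic'`, i.e. mixed-precision training with a dynamically adjusted loss scale. Roughly what that corresponds to in raw PyTorch (the training-step body is illustrative only):

import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()                  # dynamic loss scaling

def amp_train_step(model, batch, optimizer):
    optimizer.zero_grad()
    with autocast():                   # fp16 where it is safe to do so
        loss = model(batch)
    scaler.scale(loss).backward()      # scale up to avoid fp16 underflow
    scaler.step(optimizer)             # unscales, skips the step on inf/nan
    scaler.update()                    # grows/shrinks the scale over time
    return loss.detach()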
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_800e_in1k.py b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_800e_in1k.py new file mode 100644 index 0000000..f8a49b5 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_base_p16_8xb512_amp_coslr_800e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
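Every MAE schedule here is the same two-phase recipe: 40 epochs of linear warmup (`start_factor=0.0001`) followed by cosine decay over the remaining epochs (`T_max = max_epochs - 40`). A small sketch of the resulting per-epoch LR, assuming decay to zero (the default `eta_min`):

import math

def mae_lr(epoch: float, peak_lr: float, total_epochs: int,
           warmup_epochs: int = 40, start_factor: float = 1e-4) -> float:
    # Linear warmup to peak_lr, then cosine decay to 0 at total_epochs.
    if epoch < warmup_epochs:
        frac = epoch / warmup_epochs
        return peak_lr * (start_factor + (1 - start_factor) * frac)
    t = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * t))

peak = 1.5e-4 * 4096 / 256            # the lr expression written above (2.4e-3)
print(mae_lr(0, peak, 400))           # ~2.4e-7 at the start of warmup
print(mae_lr(40, peak, 400))          # 2.4e-3 at the end of warmup
print(mae_lr(400, peak, 400))         # ~0.0 at the end of training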
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=800) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000..ae1aba5 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
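The `paramwise_cfg.custom_keys` block disables weight decay for norm layers, biases, positional embeddings and mask tokens. A hedged sketch of the equivalent manual grouping in plain PyTorch; the substring matching below is a simplification of mmengine's key handling:

import torch
from torch.optim import AdamW

NO_DECAY_KEYS = ('norm', 'bias', 'pos_embed', 'mask_token')

def build_param_groups(module: torch.nn.Module, weight_decay: float = 0.05):
    decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        (no_decay if any(k in name for k in NO_DECAY_KEYS) else decay).append(param)
    return [
        dict(params=decay, weight_decay=weight_decay),
        dict(params=no_decay, weight_decay=0.0),   # decay_mult=0.0 in the config
    ]

toy = torch.nn.Linear(8, 8)   # stand-in module with a 'weight' and a 'bias'
optimizer = AdamW(build_param_groups(toy), lr=1.5e-4 * 4096 / 256, betas=(0.9, 0.95))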
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEHiViT, arch='large'), + neck=dict(type=MAEPretrainDecoder, embed_dim=768)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_400e_in1k.py b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_400e_in1k.py new file mode 100644 index 0000000..cdc1259 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_400e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEHiViT, arch='large'), + neck=dict(type=MAEPretrainDecoder, embed_dim=768)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_800e_in1k.py b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_800e_in1k.py new file mode 100644 index 0000000..657ee01 --- /dev/null +++ b/mmpretrain/configs/mae/mae_hivit_large_p16_8xb512_amp_coslr_800e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_hivit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEHiViT, arch='large'), + neck=dict(type=MAEPretrainDecoder, embed_dim=768)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=800) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True +find_unused_parameters = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000..a4b325d --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) 
+ })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py new file mode 100644 index 0000000..6cee3bc --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=260, + by_epoch=True, + begin=40, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_400e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_400e_in1k.py new file mode 100644 index 0000000..fb78e2b --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_400e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_800e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_800e_in1k.py new file mode 100644 index 0000000..f34e1da --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_800e_in1k.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) 
+ })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=800) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_huge_p14_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_vit_huge_p14_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000..bc91ee0 --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_huge_p14_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='h', patch_size=14), + neck=dict( + type=MAEPretrainDecoder, + embed_dim=1280, + patch_size=14, + num_patches=256), + head=dict(patch_size=14)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_1600e_in1k.py b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_1600e_in1k.py new file mode 100644 index 0000000..ef0777a --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_1600e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
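The ViT-huge variant above switches to 14-pixel patches and therefore overrides `num_patches=256` on the decoder; the number follows directly from the 224px crops used by the MAE pipeline. The arithmetic, for reference:

img_size, patch_size = 224, 14
assert (img_size // patch_size) ** 2 == 256   # matches num_patches=256 above
assert (img_size // 16) ** 2 == 196           # the patch-16 configs need no override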
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='l'), + neck=dict(type=MAEPretrainDecoder, embed_dim=1024)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=1600) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_300e_in1k.py b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_300e_in1k.py new file mode 100644 index 0000000..ea005e4 --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_300e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='l'), + neck=dict(type=MAEPretrainDecoder, embed_dim=1024)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.)
+ })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=260, + by_epoch=True, + begin=40, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_400e_in1k.py b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_400e_in1k.py new file mode 100644 index 0000000..6f73549 --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_400e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='l'), + neck=dict(type=MAEPretrainDecoder, embed_dim=1024)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=360, + by_epoch=True, + begin=40, + end=400, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_800e_in1k.py b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_800e_in1k.py new file mode 100644 index 0000000..a0a5abd --- /dev/null +++ b/mmpretrain/configs/mae/mae_vit_large_p16_8xb512_amp_coslr_800e_in1k.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently.
+from mmengine.config import read_base + +with read_base(): + from .._base_.models.mae_vit_base_p16 import * + from .._base_.datasets.imagenet_bs512_mae import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper +from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR +from mmengine.runner.loops import EpochBasedTrainLoop +from torch.optim.adamw import AdamW + +# model settings +model.update( + backbone=dict(type=MAEViT, arch='l'), + neck=dict(type=MAEPretrainDecoder, embed_dim=1024)) + +# optimizer wrapper +optim_wrapper = dict( + type=AmpOptimWrapper, + loss_scale='dynamic', + optimizer=dict( + type=AdamW, + lr=1.5e-4 * 4096 / 256, + betas=(0.9, 0.95), + weight_decay=0.05), + paramwise_cfg=dict( + custom_keys={ + 'ln': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.0001, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=800) +# only keeps the latest 3 checkpoints +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=1, max_keep_ckpts=3) + +randomness.update(seed=0, diff_rank_seed=True) + +# auto resume +resume = True + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=4096) diff --git a/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py b/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py new file mode 100644 index 0000000..79eec63 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.mobilenet_v2_1x import * + from .._base_.schedules.imagenet_bs256_epochstep import * diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py new file mode 100644 index 0000000..3f1bee1 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently.
+ +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict(arch='large'), + head=dict(in_channels=960, mid_channels=[1280]), + )) +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py new file mode 100644 index 0000000..50e1ffc --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict( + arch='small_050', + norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)), + head=dict(in_channels=288), + )) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline))) + +val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline))) +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. 
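The MobileNetV3 recipes follow the torchvision schedule linked in the comment: RMSprop with lr=0.064 and a `StepLR` that multiplies the LR by 0.973 every 2 epochs for 600 epochs. A quick look at how far that decays the LR by the end of training (plain arithmetic, no mmpretrain code involved):

base_lr, gamma, step_size, max_epochs = 0.064, 0.973, 2, 600

num_decays = max_epochs // step_size          # 300 multiplicative steps
final_lr = base_lr * gamma ** num_decays
print(f'{final_lr:.2e}')                      # roughly 1.7e-05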
+# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py new file mode 100644 index 0000000..c8c640c --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict( + arch='small_075', + norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)), + head=dict(in_channels=432), + )) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline))) +val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline))) +test_dataloader = val_dataloader + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py new file mode 100644 index 0000000..0c220a0 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.optim import RMSprop + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py new file mode 100644 index 0000000..0f91ee3 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.cifar10_bs16 import * + from .._base_.schedules.cifar10_bs128 import * + from .._base_.default_runtime import * + +from mmengine.optim import MultiStepLR + +# model settings +model.merge( + dict( + head=dict( + _delete_=True, + type=StackedLinearClsHead, + num_classes=10, + in_channels=576, + mid_channels=[1280], + act_cfg=dict(type=Hardswish), + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5)))) +# schedule settings +param_scheduler.merge( + dict( + type=MultiStepLR, + by_epoch=True, + milestones=[120, 170], + gamma=0.1, + )) + +train_cfg.merge(dict(by_epoch=True, max_epochs=200)) diff --git a/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py b/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py new file mode 100644 index 0000000..f16d248 --- /dev/null +++ b/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32 import * + from .._base_.default_runtime import * + from .._base_.models.resnet18 import * + from .._base_.schedules.imagenet_bs256 import * diff --git a/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py b/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py new file mode 100644 index 0000000..09c738f --- /dev/null +++ b/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
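The CIFAR-10 config above replaces the ImageNet head wholesale by passing `_delete_=True`, which tells the config merger to drop the base `head` subtree before applying the new keys, instead of recursively merging into it. A conceptual sketch of the difference (the real merging is done inside mmengine's Config machinery):

def merge(base: dict, child: dict) -> dict:
    out = dict(base)
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)
            if value.pop('_delete_', False) or not isinstance(out.get(key), dict):
                out[key] = value                    # base subtree dropped
            else:
                out[key] = merge(out[key], value)   # default: recursive merge
        else:
            out[key] = value
    return out

base_cfg = dict(head=dict(type='LinearClsHead', num_classes=1000, in_channels=576))
child_cfg = dict(head=dict(_delete_=True, type='StackedLinearClsHead',
                           num_classes=10, in_channels=576, mid_channels=[1280]))
print(merge(base_cfg, child_cfg)['head'])   # no leftover LinearClsHead fields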
+from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32_simclr import * + from .._base_.schedules.imagenet_lars_coslr_200e import * + from .._base_.default_runtime import * + +from mmengine.hooks.checkpoint_hook import CheckpointHook +from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper + +from mmpretrain.engine.optimizers.lars import LARS +from mmpretrain.models.backbones.resnet import ResNet +from mmpretrain.models.heads.contrastive_head import ContrastiveHead +from mmpretrain.models.losses.cross_entropy_loss import CrossEntropyLoss +from mmpretrain.models.necks.nonlinear_neck import NonLinearNeck +from mmpretrain.models.selfsup.simclr import SimCLR + +# dataset settings +train_dataloader.merge(dict(batch_size=256)) + +# model settings +model = dict( + type=SimCLR, + backbone=dict( + type=ResNet, + depth=50, + norm_cfg=dict(type='SyncBN'), + zero_init_residual=True), + neck=dict( + type=NonLinearNeck, # SimCLR non-linear neck + in_channels=2048, + hid_channels=2048, + out_channels=128, + num_layers=2, + with_avg_pool=True), + head=dict( + type=ContrastiveHead, + loss=dict(type=CrossEntropyLoss), + temperature=0.1), +) + +# optimizer +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict(type=LARS, lr=4.8, momentum=0.9, weight_decay=1e-6), + paramwise_cfg=dict( + custom_keys={ + 'bn': dict(decay_mult=0, lars_exclude=True), + 'bias': dict(decay_mult=0, lars_exclude=True), + # bn layer in ResNet block downsample module + 'downsample.1': dict(decay_mult=0, lars_exclude=True) + })) + +# runtime settings +default_hooks.checkpoint = dict( + type=CheckpointHook, interval=10, max_keep_ckpts=3) diff --git a/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py new file mode 100644 index 0000000..09af3d0 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(img_size=224, drop_path_rate=0.5, stage_cfgs=None), + head=dict( + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=1.0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py new file mode 100644 index 0000000..aacdc32 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently.
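The SimCLR recipe pairs a 4096 global batch with the LARS optimizer and marks BN and bias parameters with `lars_exclude=True`, so only weight tensors receive the layer-wise trust-ratio scaling. A hedged sketch of that scaling for a single tensor (momentum buffers omitted; `eta` is the trust coefficient and its value here is illustrative):

import torch

def lars_layer_lr(weight: torch.Tensor, grad: torch.Tensor,
                  base_lr: float = 4.8, weight_decay: float = 1e-6,
                  eta: float = 0.001) -> float:
    # Layer-wise LR: base_lr * eta * ||w|| / (||g|| + wd * ||w||).
    w_norm, g_norm = weight.norm(), grad.norm()
    if w_norm == 0 or g_norm == 0:
        return base_lr                 # excluded params fall back to the base LR
    trust_ratio = eta * w_norm / (g_norm + weight_decay * w_norm)
    return base_lr * trust_ratio.item()

w = torch.randn(2048, 128)
g = torch.randn_like(w) * 1e-3
print(lars_layer_lr(w, g))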
+from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py new file mode 100644 index 0000000..b8fc279 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='large', img_size=224, stage_cfgs=None), + head=dict(in_channels=1536), +) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py new file mode 100644 index 0000000..9a449aa --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='large'), + head=dict(in_channels=1536), +) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py b/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py new file mode 100644 index 0000000..2003cd3 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base +from mmengine.hooks import CheckpointHook, LoggerHook +from mmengine.model import PretrainedInit +from torch.optim.adamw import AdamW + +from mmpretrain.models import ImageClassifier + +with read_base(): + from .._base_.datasets.cub_bs8_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.cub_bs64 import * + +# model settings +checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa + +model.update( + backbone=dict( + arch='large', + init_cfg=dict( + type=PretrainedInit, checkpoint=checkpoint, prefix='backbone')), + head=dict(num_classes=200, in_channels=1536)) + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type=AdamW, + lr=5e-6, + weight_decay=0.0005, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.absolute_pos_embed': dict(decay_mult=0.0), + '.relative_position_bias_table': dict(decay_mult=0.0) + }), + clip_grad=dict(max_norm=5.0), +) + +default_hooks = dict( + # log every 20 intervals + logger=dict(type=LoggerHook, interval=20), + # save last three checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) diff --git a/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py new file mode 100644 index 0000000..5979252 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='small', img_size=224, drop_path_rate=0.3, stage_cfgs=None), + head=dict( + in_channels=768, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=1.0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py new file mode 100644 index 0000000..733e1ef --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently.
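All the supervised Swin recipes train with the batch augments `Mixup(alpha=0.8)` and `CutMix(alpha=1.0)` on top of label smoothing. A minimal sketch of the Mixup half, mixing images and one-hot targets with a Beta-sampled coefficient (shapes and the one-hot conversion are illustrative assumptions):

import torch
import torch.nn.functional as F

def mixup(images: torch.Tensor, targets: torch.Tensor, num_classes: int,
          alpha: float = 0.8):
    # Mix the batch with a shuffled copy of itself; labels become soft.
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    index = torch.randperm(images.size(0))
    mixed_images = lam * images + (1 - lam) * images[index]
    one_hot = F.one_hot(targets, num_classes).float()
    mixed_targets = lam * one_hot + (1 - lam) * one_hot[index]
    return mixed_images, mixed_targets

imgs = torch.randn(4, 3, 224, 224)
labels = torch.randint(0, 1000, (4,))
mixed_imgs, soft_labels = mixup(imgs, labels, num_classes=1000)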
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='tiny', img_size=224, drop_path_rate=0.2, stage_cfgs=None), + head=dict( + in_channels=768, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py new file mode 100644 index 0000000..1ecc436 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]), + head=dict(num_classes=21841), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# dataset settings +data_preprocessor = dict(num_classes=21841) + +_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop +_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge +_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py new file mode 100644 index 0000000..103afb4 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
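+# Train SwinV2-Base on ImageNet-1k at 256x256 input. ``window_size`` is
+# specified per stage; the last stage uses a smaller window (8) because its
+# feature map is only 8x8 at this input resolution.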
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=256, drop_path_rate=0.5, window_size=[16, 16, 16, 8]), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py new file mode 100644 index 0000000..6588f50 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=256, + window_size=[16, 16, 16, 8], + pretrained_window_sizes=[12, 12, 12, 6]), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py new file mode 100644 index 0000000..118c085 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + window_size=[24, 24, 24, 12], pretrained_window_sizes=[12, 12, 12, 6])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py new file mode 100644 index 0000000..d40144c --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base
+from mmengine.model import ConstantInit, TruncNormalInit
+
+from mmpretrain.models import CutMix, Mixup
+
+with read_base():
+    from .._base_.datasets.imagenet_bs64_swin_256 import *
+    from .._base_.default_runtime import *
+    from .._base_.models.swin_transformer_v2_base import *
+    from .._base_.schedules.imagenet_bs1024_adamw_swin import *
+
+# model settings
+model.update(
+    backbone=dict(img_size=256, drop_path_rate=0.5),
+    init_cfg=[
+        dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
+        dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
+    ],
+    train_cfg=dict(
+        augments=[dict(type=Mixup, alpha=0.8),
+                  dict(type=CutMix, alpha=1.0)]))
diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py
new file mode 100644
index 0000000..1ecc436
--- /dev/null
+++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change recently.
+from mmengine.config import read_base
+from mmengine.model import ConstantInit, TruncNormalInit
+
+from mmpretrain.models import CutMix, Mixup
+
+with read_base():
+    from .._base_.datasets.imagenet21k_bs128 import *
+    from .._base_.default_runtime import *
+    from .._base_.models.swin_transformer_v2_base import *
+    from .._base_.schedules.imagenet_bs1024_adamw_swin import *
+
+# model settings
+model.update(
+    backbone=dict(
+        arch='large', img_size=192, drop_path_rate=0.5,
+        window_size=[12, 12, 12, 6]),
+    head=dict(num_classes=21841, in_channels=1536),
+    init_cfg=[
+        dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
+        dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
+    ],
+    train_cfg=dict(
+        augments=[dict(type=Mixup, alpha=0.8),
+                  dict(type=CutMix, alpha=1.0)]))
+
+# dataset settings
+data_preprocessor = dict(num_classes=21841)
+
+_base_['train_pipeline'][1]['scale'] = 192  # RandomResizedCrop
+_base_['test_pipeline'][1]['scale'] = 219  # ResizeEdge
+_base_['test_pipeline'][2]['crop_size'] = 192  # CenterCrop
diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py
new file mode 100644
index 0000000..0a1b59d
--- /dev/null
+++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py
@@ -0,0 +1,24 @@
+# Only for evaluation
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change recently.
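+# Evaluation-only config for SwinV2-Large pre-trained on ImageNet-21k and
+# fine-tuned on ImageNet-1k at 256x256. ``pretrained_window_sizes`` records
+# the per-stage window sizes used during pre-training, so the log-spaced
+# continuous relative position bias can be rescaled for the new window size.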
+from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='large', + img_size=256, + window_size=[16, 16, 16, 8], + pretrained_window_sizes=[12, 12, 12, 6]), + head=dict( + in_channels=1536, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py new file mode 100644 index 0000000..b20bcea --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py @@ -0,0 +1,24 @@ +# Only for evaluation +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='large', + img_size=384, + window_size=[24, 24, 24, 12], + pretrained_window_sizes=[12, 12, 12, 6]), + head=dict( + in_channels=1536, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py new file mode 100644 index 0000000..dfd15c3 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='small', + img_size=256, + drop_path_rate=0.3, + window_size=[16, 16, 16, 8]), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py new file mode 100644 index 0000000..bfec346 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='small', img_size=256, drop_path_rate=0.3), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py new file mode 100644 index 0000000..f2fa160 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='tiny', + img_size=256, + drop_path_rate=0.2, + window_size=[16, 16, 16, 8]), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py new file mode 100644 index 0000000..8cca2b3 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='tiny', img_size=256, drop_path_rate=0.2), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p16_32xb128_mae_in1k.py b/mmpretrain/configs/vision_transformer/vit_base_p16_32xb128_mae_in1k.py new file mode 100644 index 0000000..18c2afd --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p16_32xb128_mae_in1k.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit +from torch.optim import AdamW + +from mmpretrain.engine import EMAHook +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +model.update( + backbone=dict(drop_rate=0, drop_path_rate=0.1, init_cfg=None), + head=dict(loss=dict(mode='original')), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=.02), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# dataset settings +train_dataloader.update(batch_size=128) + +# schedule settings +optim_wrapper.update( + optimizer=dict( + type=AdamW, + lr=1e-4 * 4096 / 256, + weight_decay=0.3, + eps=1e-8, + betas=(0.9, 0.95)), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + })) + +# runtime settings +custom_hooks = [dict(type=EMAHook, momentum=1e-4)] + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (32 GPUs) x (128 samples per GPU) +auto_scale_lr.update(base_batch_size=4096) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k.py b/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k.py new file mode 100644 index 0000000..8f128d1 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize_autoaug import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + head=dict(hidden_dim=3072), + train_cfg=dict(augments=dict(type=Mixup, alpha=0.2)), +) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k_384px.py b/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k_384px.py new file mode 100644 index 0000000..98e01f3 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p16_64xb64_in1k_384px.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
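+# ViT-Base/16 at 384x384 input. The data preprocessor is overridden to
+# normalize images to [-1, 1] (mean/std of 127.5), and the train/test
+# pipelines are rebuilt below so the crop and resize scales match the larger
+# input resolution.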
+from mmengine.config import read_base + +from mmpretrain.datasets import (CenterCrop, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, ResizeEdge) + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update(backbone=dict(img_size=384)) + +# dataset setting +data_preprocessor.update( + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=384, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=384, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader.update(dataset=dict(pipeline=train_pipeline)) +val_dataloader.update(dataset=dict(pipeline=test_pipeline)) +test_dataloader.update(dataset=dict(pipeline=test_pipeline)) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k.py b/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k.py new file mode 100644 index 0000000..3651c93 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize_autoaug import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(patch_size=32), + head=dict( + hidden_dim=3072, + topk=(1, 5), + ), + train_cfg=dict(augments=dict(type=Mixup, alpha=0.2)), +) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k_384px.py b/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k_384px.py new file mode 100644 index 0000000..253740c --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_base_p32_64xb64_in1k_384px.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base + +from mmpretrain.datasets import (CenterCrop, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(img_size=384, patch_size=32), head=dict(topk=(1, 5))) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# dataset setting +data_preprocessor.update( + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=384, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=384, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader.update(dataset=dict(pipeline=train_pipeline)) +val_dataloader.update(dataset=dict(pipeline=test_pipeline)) +test_dataloader.update(dataset=dict(pipeline=test_pipeline)) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k.py b/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k.py new file mode 100644 index 0000000..03f4a74 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize_autoaug import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(arch='l'), + head=dict( + hidden_dim=3072, + in_channels=1024, + topk=(1, 5), + ), + train_cfg=dict(augments=dict(type=Mixup, alpha=0.2)), +) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k_384px.py b/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k_384px.py new file mode 100644 index 0000000..eba4bc4 --- /dev/null +++ b/mmpretrain/configs/vision_transformer/vit_large_p16_64xb64_in1k_384px.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
+from mmengine.config import read_base
+
+from mmpretrain.datasets import (CenterCrop, LoadImageFromFile, PackInputs,
+                                 RandomFlip, RandomResizedCrop, ResizeEdge)
+from mmpretrain.models import CrossEntropyLoss
+
+with read_base():
+    from .._base_.datasets.imagenet_bs64_pil_resize import *
+    from .._base_.default_runtime import *
+    from .._base_.models.vit_base_p16 import *
+    from .._base_.schedules.imagenet_bs4096_adamw import *
+
+# model setting
+model.update(
+    backbone=dict(arch='l', img_size=384),
+    head=dict(in_channels=1024, topk=(1, 5)))
+
+model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0)
+
+# dataset setting
+data_preprocessor.update(
+    mean=[127.5, 127.5, 127.5],
+    std=[127.5, 127.5, 127.5],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
+
+train_pipeline = [
+    dict(type=LoadImageFromFile),
+    dict(type=RandomResizedCrop, scale=384, backend='pillow'),
+    dict(type=RandomFlip, prob=0.5, direction='horizontal'),
+    dict(type=PackInputs),
+]
+
+test_pipeline = [
+    dict(type=LoadImageFromFile),
+    dict(type=ResizeEdge, scale=384, edge='short', backend='pillow'),
+    dict(type=CenterCrop, crop_size=384),
+    dict(type=PackInputs),
+]
+
+train_dataloader.update(dataset=dict(pipeline=train_pipeline))
+val_dataloader.update(dataset=dict(pipeline=test_pipeline))
+test_dataloader.update(dataset=dict(pipeline=test_pipeline))
+
+# schedule setting
+optim_wrapper.update(clip_grad=dict(max_norm=1.0))
diff --git a/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k.py b/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k.py
new file mode 100644
index 0000000..73dae6e
--- /dev/null
+++ b/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change recently.
+from mmengine.config import read_base
+
+from mmpretrain.models import CrossEntropyLoss, Mixup
+
+with read_base():
+    from .._base_.datasets.imagenet_bs64_pil_resize_autoaug import *
+    from .._base_.default_runtime import *
+    from .._base_.models.vit_base_p16 import *
+    from .._base_.schedules.imagenet_bs4096_adamw import *
+
+# model setting
+model.update(
+    backbone=dict(arch='l', patch_size=32),
+    head=dict(
+        hidden_dim=3072,
+        in_channels=1024,
+        topk=(1, 5),
+    ),
+    train_cfg=dict(augments=dict(type=Mixup, alpha=0.2)),
+)
+
+model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0)
+
+# schedule setting
+optim_wrapper.update(clip_grad=dict(max_norm=1.0))
diff --git a/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k_384px.py b/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k_384px.py
new file mode 100644
index 0000000..82e1619
--- /dev/null
+++ b/mmpretrain/configs/vision_transformer/vit_large_p32_64xb64_in1k_384px.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change recently.
+from mmengine.config import read_base + +from mmpretrain.datasets import (CenterCrop, LoadImageFromFile, PackInputs, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.vit_base_p16 import * + from .._base_.schedules.imagenet_bs4096_adamw import * + +# model setting +model.update( + backbone=dict(arch='l', img_size=384, patch_size=32), + head=dict(in_channels=1024, topk=(1, 5))) + +model.head.loss = dict(type=CrossEntropyLoss, loss_weight=1.0) + +# dataset setting +data_preprocessor.update( + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=384, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=384, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader.update(dataset=dict(pipeline=train_pipeline)) +val_dataloader.update(dataset=dict(pipeline=test_pipeline)) +test_dataloader.update(dataset=dict(pipeline=test_pipeline)) + +# schedule setting +optim_wrapper.update(clip_grad=dict(max_norm=1.0)) diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py new file mode 100644 index 0000000..e621e15 --- /dev/null +++ b/mmpretrain/datasets/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpretrain.utils.dependency import WITH_MULTIMODAL +from .base_dataset import BaseDataset +from .builder import build_dataset +from .caltech101 import Caltech101 +from .cifar import CIFAR10, CIFAR100 +from .cub import CUB +from .custom import CustomDataset +from .dataset_wrappers import KFoldDataset +from .dtd import DTD +from .fgvcaircraft import FGVCAircraft +from .flowers102 import Flowers102 +from .food101 import Food101 +from .imagenet import ImageNet, ImageNet21k +from .inshop import InShop +from .mnist import MNIST, FashionMNIST +from .multi_label import MultiLabelDataset +from .multi_task import MultiTaskDataset +from .nlvr2 import NLVR2 +from .oxfordiiitpet import OxfordIIITPet +from .places205 import Places205 +from .samplers import * # noqa: F401,F403 +from .stanfordcars import StanfordCars +from .sun397 import SUN397 +from .transforms import * # noqa: F401,F403 +from .voc import VOC + +__all__ = [ + 'BaseDataset', 'CIFAR10', 'CIFAR100', 'CUB', 'Caltech101', 'CustomDataset', + 'DTD', 'FGVCAircraft', 'FashionMNIST', 'Flowers102', 'Food101', 'ImageNet', + 'ImageNet21k', 'InShop', 'KFoldDataset', 'MNIST', 'MultiLabelDataset', + 'MultiTaskDataset', 'NLVR2', 'OxfordIIITPet', 'Places205', 'SUN397', + 'StanfordCars', 'VOC', 'build_dataset' +] + +if WITH_MULTIMODAL: + from .coco_caption import COCOCaption + from .coco_retrieval import COCORetrieval + from .coco_vqa import COCOVQA + from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA + from .flickr30k_caption import Flickr30kCaption + from .flickr30k_retrieval import Flickr30kRetrieval + from .gqa_dataset import GQA + from .iconqa import IconQA + from .infographic_vqa import InfographicVQA + from .minigpt4_dataset import MiniGPT4Dataset + from .nocaps import NoCaps + from .ocr_vqa import OCRVQA + from .refcoco import RefCOCO + from .scienceqa import ScienceQA + 
    from .textvqa import TextVQA
+    from .visual_genome import VisualGenomeQA
+    from .vizwiz import VizWiz
+    from .vsr import VSR
+
+    __all__.extend([
+        'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
+        'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
+        'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
+        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
+        'MiniGPT4Dataset'
+    ])
diff --git a/mmpretrain/datasets/base_dataset.py b/mmpretrain/datasets/base_dataset.py
new file mode 100644
index 0000000..dffdf04
--- /dev/null
+++ b/mmpretrain/datasets/base_dataset.py
@@ -0,0 +1,219 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from os import PathLike
+from typing import List, Optional, Sequence, Union
+
+import mmengine
+import numpy as np
+from mmengine.dataset import BaseDataset as _BaseDataset
+
+from mmpretrain.registry import DATASETS, TRANSFORMS
+
+
+def expanduser(path):
+    """Expand ~ and ~user constructions.
+
+    If user or $HOME is unknown, do nothing.
+    """
+    if isinstance(path, (str, PathLike)):
+        return osp.expanduser(path)
+    else:
+        return path
+
+
+@DATASETS.register_module()
+class BaseDataset(_BaseDataset):
+    """Base dataset for image classification task.
+
+    This dataset supports annotation files in the `OpenMMLab 2.0 style
+    annotation format`_.
+
+    .. _OpenMMLab 2.0 style annotation format:
+        https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md
+
+    Compared with :class:`mmengine.BaseDataset`, this class implements
+    several useful methods.
+
+    Args:
+        ann_file (str): Annotation file path.
+        metainfo (dict, optional): Meta information for the dataset, such as
+            class information. Defaults to None.
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to ''.
+        data_prefix (str | dict): Prefix for training data. Defaults to ''.
+        filter_cfg (dict, optional): Config for filtering data.
+            Defaults to None.
+        indices (int or Sequence[int], optional): Support using only the
+            first few samples in the annotation file to facilitate
+            training/testing on a smaller dataset. Defaults to None, which
+            means using all ``data_infos``.
+        serialize_data (bool): Whether to hold memory using serialized
+            objects. When enabled, data loader workers can use shared RAM
+            from the master process instead of making a copy.
+            Defaults to True.
+        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+        test_mode (bool, optional): ``test_mode=True`` means in test phase,
+            where an error will be raised if getting an item fails;
+            ``test_mode=False`` means in training phase, where another item
+            will be returned randomly. Defaults to False.
+        lazy_init (bool): Whether to load annotations during instantiation.
+            In some cases, such as visualization, only the meta information
+            of the dataset is needed, so it is not necessary to load the
+            annotation file. ``BaseDataset`` can skip loading annotations to
+            save time by setting ``lazy_init=True``. Defaults to False.
+        max_refetch (int): The maximum number of extra cycles to get a valid
+            image when ``BaseDataset.prepare_data`` returns a None image.
+            Defaults to 1000.
+        classes (str | Sequence[str], optional): Specify names of classes.
+
+            - If it is a string, it should be a file path, and every line of
+              the file is a class name.
+            - If it is a sequence of strings, every item is a class name.
+            - If it is None, use the category information in the ``metainfo``
+              argument, the annotation file, or the class attribute
+              ``METAINFO``.
+
+            Defaults to None.
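+
+    Examples:
+        A minimal sketch; the annotation file below is hypothetical, and
+        ``lazy_init=True`` skips loading it at construction time.
+
+        >>> from mmpretrain.datasets import BaseDataset
+        >>> dataset = BaseDataset(
+        ...     ann_file='meta/train.json',
+        ...     data_root='data/my_dataset',
+        ...     data_prefix='images',
+        ...     classes=['cat', 'dog'],
+        ...     lazy_init=True)
+        >>> dataset.CLASSES
+        ('cat', 'dog')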
+ """ # noqa: E501 + + def __init__(self, + ann_file: str, + metainfo: Optional[dict] = None, + data_root: str = '', + data_prefix: Union[str, dict] = '', + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: Sequence = (), + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + classes: Union[str, Sequence[str], None] = None): + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + + ann_file = expanduser(ann_file) + metainfo = self._compat_classes(metainfo, classes) + + transforms = [] + for transform in pipeline: + if isinstance(transform, dict): + transforms.append(TRANSFORMS.build(transform)) + else: + transforms.append(transform) + + super().__init__( + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=transforms, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + @property + def img_prefix(self): + """The prefix of images.""" + return self.data_prefix['img_path'] + + @property + def CLASSES(self): + """Return all categories names.""" + return self._metainfo.get('classes', None) + + @property + def class_to_idx(self): + """Map mapping class name to class index. + + Returns: + dict: mapping from class name to class index. + """ + + return {cat: i for i, cat in enumerate(self.CLASSES)} + + def get_gt_labels(self): + """Get all ground-truth labels (categories). + + Returns: + np.ndarray: categories for all images. + """ + + gt_labels = np.array( + [self.get_data_info(i)['gt_label'] for i in range(len(self))]) + return gt_labels + + def get_cat_ids(self, idx: int) -> List[int]: + """Get category id by index. + + Args: + idx (int): Index of data. + + Returns: + cat_ids (List[int]): Image category of specified index. + """ + + return [int(self.get_data_info(idx)['gt_label'])] + + def _compat_classes(self, metainfo, classes): + """Merge the old style ``classes`` arguments to ``metainfo``.""" + if isinstance(classes, str): + # take it as a file path + class_names = mmengine.list_from_file(expanduser(classes)) + elif isinstance(classes, (tuple, list)): + class_names = classes + elif classes is not None: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + if metainfo is None: + metainfo = {} + + if classes is not None: + metainfo = {'classes': tuple(class_names), **metainfo} + + return metainfo + + def full_init(self): + """Load annotation file and set ``BaseDataset._fully_initialized`` to + True.""" + super().full_init() + + # To support the standard OpenMMLab 2.0 annotation format. Generate + # metainfo in internal format from standard metainfo format. + if 'categories' in self._metainfo and 'classes' not in self._metainfo: + categories = sorted( + self._metainfo['categories'], key=lambda x: x['id']) + self._metainfo['classes'] = tuple( + [cat['category_name'] for cat in categories]) + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. 
+        """
+        head = 'Dataset ' + self.__class__.__name__
+        body = []
+        if self._fully_initialized:
+            body.append(f'Number of samples: \t{self.__len__()}')
+        else:
+            body.append("Haven't been initialized")
+
+        if self.CLASSES is not None:
+            body.append(f'Number of categories: \t{len(self.CLASSES)}')
+
+        body.extend(self.extra_repr())
+
+        if len(self.pipeline.transforms) > 0:
+            body.append('With transforms:')
+            for t in self.pipeline.transforms:
+                body.append(f'    {t}')
+
+        lines = [head] + [' ' * 4 + line for line in body]
+        return '\n'.join(lines)
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = []
+        body.append(f'Annotation file: \t{self.ann_file}')
+        body.append(f'Prefix of images: \t{self.img_prefix}')
+        return body
diff --git a/mmpretrain/datasets/builder.py b/mmpretrain/datasets/builder.py
new file mode 100644
index 0000000..dfa3872
--- /dev/null
+++ b/mmpretrain/datasets/builder.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpretrain.registry import DATASETS
+
+
+def build_dataset(cfg):
+    """Build dataset.
+
+    Examples:
+        >>> from mmpretrain.datasets import build_dataset
+        >>> mnist_train = build_dataset(
+        ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
+        >>> print(mnist_train)
+        Dataset MNIST
+            Number of samples: 60000
+            Number of categories: 10
+            Prefix of data: data/mnist/
+        >>> mnist_test = build_dataset(
+        ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=True))
+        >>> print(mnist_test)
+        Dataset MNIST
+            Number of samples: 10000
+            Number of categories: 10
+            Prefix of data: data/mnist/
+    """
+    return DATASETS.build(cfg)
diff --git a/mmpretrain/datasets/caltech101.py b/mmpretrain/datasets/caltech101.py
new file mode 100644
index 0000000..71e5de8
--- /dev/null
+++ b/mmpretrain/datasets/caltech101.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import CALTECH101_CATEGORIES
+
+
+@DATASETS.register_module()
+class Caltech101(BaseDataset):
+    """The Caltech101 Dataset.
+
+    Support the `Caltech101 `_ Dataset.
+    After downloading and decompression, the dataset directory structure is
+    as follows.
+
+    Caltech101 dataset directory: ::
+
+        caltech-101
+        ├── 101_ObjectCategories
+        │   ├── class_x
+        │   │   ├── xx1.jpg
+        │   │   ├── xx2.jpg
+        │   │   └── ...
+        │   ├── class_y
+        │   │   ├── yy1.jpg
+        │   │   ├── yy2.jpg
+        │   │   └── ...
+        │   └── ...
+        ├── Annotations
+        │   ├── class_x
+        │   │   ├── xx1.mat
+        │   │   └── ...
+        │   └── ...
+        ├── meta
+        │   ├── train.txt
+        │   └── test.txt
+        └── ....
+
+    Please note that since there is no official splitting for training and
+    test set, you can use the train.txt and test.txt provided by us or
+    create your own annotation files. Here is the download
+    `link `_
+    for the annotations.
+
+    Args:
+        data_root (str): The root directory for the Caltech101 dataset.
+        split (str, optional): The dataset split, supports "train" and "test".
+            Defaults to "train".
+ + Examples: + >>> from mmpretrain.datasets import Caltech101 + >>> train_dataset = Caltech101(data_root='data/caltech-101', split='train') + >>> train_dataset + Dataset Caltech101 + Number of samples: 3060 + Number of categories: 102 + Root of dataset: data/caltech-101 + >>> test_dataset = Caltech101(data_root='data/caltech-101', split='test') + >>> test_dataset + Dataset Caltech101 + Number of samples: 6728 + Number of categories: 102 + Root of dataset: data/caltech-101 + """ # noqa: E501 + + METAINFO = {'classes': CALTECH101_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'train', **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + + if split == 'train': + ann_file = self.backend.join_path('meta', 'train.txt') + else: + ann_file = self.backend.join_path('meta', 'test.txt') + + data_prefix = '101_ObjectCategories' + test_mode = split == 'test' + + super(Caltech101, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + + for pair in pairs: + path, gt_label = pair.split() + img_path = self.backend.join_path(self.img_prefix, path) + info = dict(img_path=img_path, gt_label=int(gt_label)) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/categories.py b/mmpretrain/datasets/categories.py new file mode 100644 index 0000000..9e75f79 --- /dev/null +++ b/mmpretrain/datasets/categories.py @@ -0,0 +1,1661 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Pre-defined categories names of various datasets. 
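+#
+# Each tuple below is used as the ``classes`` entry of a dataset's
+# ``METAINFO`` (e.g. ``METAINFO = {'classes': CALTECH101_CATEGORIES}`` in
+# ``caltech101.py``), so the order of the names must match the integer
+# labels in the corresponding annotation files.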
+ +VOC2007_CATEGORIES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', + 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', + 'sofa', 'train', 'tvmonitor') + +CUB_CATEGORIES = ( + 'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross', + 'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet', 'Parakeet_Auklet', + 'Rhinoceros_Auklet', 'Brewer_Blackbird', 'Red_winged_Blackbird', + 'Rusty_Blackbird', 'Yellow_headed_Blackbird', 'Bobolink', 'Indigo_Bunting', + 'Lazuli_Bunting', 'Painted_Bunting', 'Cardinal', 'Spotted_Catbird', + 'Gray_Catbird', 'Yellow_breasted_Chat', 'Eastern_Towhee', + 'Chuck_will_Widow', 'Brandt_Cormorant', 'Red_faced_Cormorant', + 'Pelagic_Cormorant', 'Bronzed_Cowbird', 'Shiny_Cowbird', 'Brown_Creeper', + 'American_Crow', 'Fish_Crow', 'Black_billed_Cuckoo', 'Mangrove_Cuckoo', + 'Yellow_billed_Cuckoo', 'Gray_crowned_Rosy_Finch', 'Purple_Finch', + 'Northern_Flicker', 'Acadian_Flycatcher', 'Great_Crested_Flycatcher', + 'Least_Flycatcher', 'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher', + 'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird', + 'Northern_Fulmar', 'Gadwall', 'American_Goldfinch', 'European_Goldfinch', + 'Boat_tailed_Grackle', 'Eared_Grebe', 'Horned_Grebe', 'Pied_billed_Grebe', + 'Western_Grebe', 'Blue_Grosbeak', 'Evening_Grosbeak', 'Pine_Grosbeak', + 'Rose_breasted_Grosbeak', 'Pigeon_Guillemot', 'California_Gull', + 'Glaucous_winged_Gull', 'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull', + 'Ring_billed_Gull', 'Slaty_backed_Gull', 'Western_Gull', + 'Anna_Hummingbird', 'Ruby_throated_Hummingbird', 'Rufous_Hummingbird', + 'Green_Violetear', 'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay', + 'Florida_Jay', 'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird', + 'Gray_Kingbird', 'Belted_Kingfisher', 'Green_Kingfisher', + 'Pied_Kingfisher', 'Ringed_Kingfisher', 'White_breasted_Kingfisher', + 'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard', + 'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser', + 'Mockingbird', 'Nighthawk', 'Clark_Nutcracker', 'White_breasted_Nuthatch', + 'Baltimore_Oriole', 'Hooded_Oriole', 'Orchard_Oriole', 'Scott_Oriole', + 'Ovenbird', 'Brown_Pelican', 'White_Pelican', 'Western_Wood_Pewee', + 'Sayornis', 'American_Pipit', 'Whip_poor_Will', 'Horned_Puffin', + 'Common_Raven', 'White_necked_Raven', 'American_Redstart', 'Geococcyx', + 'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow', + 'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow', + 'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow', 'Fox_Sparrow', + 'Grasshopper_Sparrow', 'Harris_Sparrow', 'Henslow_Sparrow', + 'Le_Conte_Sparrow', 'Lincoln_Sparrow', 'Nelson_Sharp_tailed_Sparrow', + 'Savannah_Sparrow', 'Seaside_Sparrow', 'Song_Sparrow', 'Tree_Sparrow', + 'Vesper_Sparrow', 'White_crowned_Sparrow', 'White_throated_Sparrow', + 'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow', 'Cliff_Swallow', + 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager', 'Artic_Tern', + 'Black_Tern', 'Caspian_Tern', 'Common_Tern', 'Elegant_Tern', + 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee', 'Brown_Thrasher', + 'Sage_Thrasher', 'Black_capped_Vireo', 'Blue_headed_Vireo', + 'Philadelphia_Vireo', 'Red_eyed_Vireo', 'Warbling_Vireo', + 'White_eyed_Vireo', 'Yellow_throated_Vireo', 'Bay_breasted_Warbler', + 'Black_and_white_Warbler', 'Black_throated_Blue_Warbler', + 'Blue_winged_Warbler', 'Canada_Warbler', 'Cape_May_Warbler', + 'Cerulean_Warbler', 
'Chestnut_sided_Warbler', 'Golden_winged_Warbler', + 'Hooded_Warbler', 'Kentucky_Warbler', 'Magnolia_Warbler', + 'Mourning_Warbler', 'Myrtle_Warbler', 'Nashville_Warbler', + 'Orange_crowned_Warbler', 'Palm_Warbler', 'Pine_Warbler', + 'Prairie_Warbler', 'Prothonotary_Warbler', 'Swainson_Warbler', + 'Tennessee_Warbler', 'Wilson_Warbler', 'Worm_eating_Warbler', + 'Yellow_Warbler', 'Northern_Waterthrush', 'Louisiana_Waterthrush', + 'Bohemian_Waxwing', 'Cedar_Waxwing', 'American_Three_toed_Woodpecker', + 'Pileated_Woodpecker', 'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker', + 'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren', 'Cactus_Wren', + 'Carolina_Wren', 'House_Wren', 'Marsh_Wren', 'Rock_Wren', 'Winter_Wren', + 'Common_Yellowthroat') + +IMAGENET_CATEGORIES = ( + 'tench, Tinca tinca', + 'goldfish, Carassius auratus', + 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', # noqa: E501 + 'tiger shark, Galeocerdo cuvieri', + 'hammerhead, hammerhead shark', + 'electric ray, crampfish, numbfish, torpedo', + 'stingray', + 'cock', + 'hen', + 'ostrich, Struthio camelus', + 'brambling, Fringilla montifringilla', + 'goldfinch, Carduelis carduelis', + 'house finch, linnet, Carpodacus mexicanus', + 'junco, snowbird', + 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', + 'robin, American robin, Turdus migratorius', + 'bulbul', + 'jay', + 'magpie', + 'chickadee', + 'water ouzel, dipper', + 'kite', + 'bald eagle, American eagle, Haliaeetus leucocephalus', + 'vulture', + 'great grey owl, great gray owl, Strix nebulosa', + 'European fire salamander, Salamandra salamandra', + 'common newt, Triturus vulgaris', + 'eft', + 'spotted salamander, Ambystoma maculatum', + 'axolotl, mud puppy, Ambystoma mexicanum', + 'bullfrog, Rana catesbeiana', + 'tree frog, tree-frog', + 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', + 'loggerhead, loggerhead turtle, Caretta caretta', + 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', # noqa: E501 + 'mud turtle', + 'terrapin', + 'box turtle, box tortoise', + 'banded gecko', + 'common iguana, iguana, Iguana iguana', + 'American chameleon, anole, Anolis carolinensis', + 'whiptail, whiptail lizard', + 'agama', + 'frilled lizard, Chlamydosaurus kingi', + 'alligator lizard', + 'Gila monster, Heloderma suspectum', + 'green lizard, Lacerta viridis', + 'African chameleon, Chamaeleo chamaeleon', + 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', # noqa: E501 + 'African crocodile, Nile crocodile, Crocodylus niloticus', + 'American alligator, Alligator mississipiensis', + 'triceratops', + 'thunder snake, worm snake, Carphophis amoenus', + 'ringneck snake, ring-necked snake, ring snake', + 'hognose snake, puff adder, sand viper', + 'green snake, grass snake', + 'king snake, kingsnake', + 'garter snake, grass snake', + 'water snake', + 'vine snake', + 'night snake, Hypsiglena torquata', + 'boa constrictor, Constrictor constrictor', + 'rock python, rock snake, Python sebae', + 'Indian cobra, Naja naja', + 'green mamba', + 'sea snake', + 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', + 'diamondback, diamondback rattlesnake, Crotalus adamanteus', + 'sidewinder, horned rattlesnake, Crotalus cerastes', + 'trilobite', + 'harvestman, daddy longlegs, Phalangium opilio', + 'scorpion', + 'black and gold garden spider, Argiope aurantia', + 'barn spider, Araneus cavaticus', + 'garden spider, Aranea diademata', + 'black widow, Latrodectus mactans', 
+ 'tarantula', + 'wolf spider, hunting spider', + 'tick', + 'centipede', + 'black grouse', + 'ptarmigan', + 'ruffed grouse, partridge, Bonasa umbellus', + 'prairie chicken, prairie grouse, prairie fowl', + 'peacock', + 'quail', + 'partridge', + 'African grey, African gray, Psittacus erithacus', + 'macaw', + 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', + 'lorikeet', + 'coucal', + 'bee eater', + 'hornbill', + 'hummingbird', + 'jacamar', + 'toucan', + 'drake', + 'red-breasted merganser, Mergus serrator', + 'goose', + 'black swan, Cygnus atratus', + 'tusker', + 'echidna, spiny anteater, anteater', + 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', # noqa: E501 + 'wallaby, brush kangaroo', + 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', # noqa: E501 + 'wombat', + 'jellyfish', + 'sea anemone, anemone', + 'brain coral', + 'flatworm, platyhelminth', + 'nematode, nematode worm, roundworm', + 'conch', + 'snail', + 'slug', + 'sea slug, nudibranch', + 'chiton, coat-of-mail shell, sea cradle, polyplacophore', + 'chambered nautilus, pearly nautilus, nautilus', + 'Dungeness crab, Cancer magister', + 'rock crab, Cancer irroratus', + 'fiddler crab', + 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', # noqa: E501 + 'American lobster, Northern lobster, Maine lobster, Homarus americanus', # noqa: E501 + 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', # noqa: E501 + 'crayfish, crawfish, crawdad, crawdaddy', + 'hermit crab', + 'isopod', + 'white stork, Ciconia ciconia', + 'black stork, Ciconia nigra', + 'spoonbill', + 'flamingo', + 'little blue heron, Egretta caerulea', + 'American egret, great white heron, Egretta albus', + 'bittern', + 'crane', + 'limpkin, Aramus pictus', + 'European gallinule, Porphyrio porphyrio', + 'American coot, marsh hen, mud hen, water hen, Fulica americana', + 'bustard', + 'ruddy turnstone, Arenaria interpres', + 'red-backed sandpiper, dunlin, Erolia alpina', + 'redshank, Tringa totanus', + 'dowitcher', + 'oystercatcher, oyster catcher', + 'pelican', + 'king penguin, Aptenodytes patagonica', + 'albatross, mollymawk', + 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', # noqa: E501 + 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', + 'dugong, Dugong dugon', + 'sea lion', + 'Chihuahua', + 'Japanese spaniel', + 'Maltese dog, Maltese terrier, Maltese', + 'Pekinese, Pekingese, Peke', + 'Shih-Tzu', + 'Blenheim spaniel', + 'papillon', + 'toy terrier', + 'Rhodesian ridgeback', + 'Afghan hound, Afghan', + 'basset, basset hound', + 'beagle', + 'bloodhound, sleuthhound', + 'bluetick', + 'black-and-tan coonhound', + 'Walker hound, Walker foxhound', + 'English foxhound', + 'redbone', + 'borzoi, Russian wolfhound', + 'Irish wolfhound', + 'Italian greyhound', + 'whippet', + 'Ibizan hound, Ibizan Podenco', + 'Norwegian elkhound, elkhound', + 'otterhound, otter hound', + 'Saluki, gazelle hound', + 'Scottish deerhound, deerhound', + 'Weimaraner', + 'Staffordshire bullterrier, Staffordshire bull terrier', + 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', # noqa: E501 + 'Bedlington terrier', + 'Border terrier', + 'Kerry blue terrier', + 'Irish terrier', + 'Norfolk terrier', + 'Norwich terrier', + 'Yorkshire terrier', + 'wire-haired fox terrier', + 'Lakeland terrier', + 'Sealyham terrier, Sealyham', + 'Airedale, Airedale terrier', + 'cairn, cairn 
terrier', + 'Australian terrier', + 'Dandie Dinmont, Dandie Dinmont terrier', + 'Boston bull, Boston terrier', + 'miniature schnauzer', + 'giant schnauzer', + 'standard schnauzer', + 'Scotch terrier, Scottish terrier, Scottie', + 'Tibetan terrier, chrysanthemum dog', + 'silky terrier, Sydney silky', + 'soft-coated wheaten terrier', + 'West Highland white terrier', + 'Lhasa, Lhasa apso', + 'flat-coated retriever', + 'curly-coated retriever', + 'golden retriever', + 'Labrador retriever', + 'Chesapeake Bay retriever', + 'German short-haired pointer', + 'vizsla, Hungarian pointer', + 'English setter', + 'Irish setter, red setter', + 'Gordon setter', + 'Brittany spaniel', + 'clumber, clumber spaniel', + 'English springer, English springer spaniel', + 'Welsh springer spaniel', + 'cocker spaniel, English cocker spaniel, cocker', + 'Sussex spaniel', + 'Irish water spaniel', + 'kuvasz', + 'schipperke', + 'groenendael', + 'malinois', + 'briard', + 'kelpie', + 'komondor', + 'Old English sheepdog, bobtail', + 'Shetland sheepdog, Shetland sheep dog, Shetland', + 'collie', + 'Border collie', + 'Bouvier des Flandres, Bouviers des Flandres', + 'Rottweiler', + 'German shepherd, German shepherd dog, German police dog, alsatian', + 'Doberman, Doberman pinscher', + 'miniature pinscher', + 'Greater Swiss Mountain dog', + 'Bernese mountain dog', + 'Appenzeller', + 'EntleBucher', + 'boxer', + 'bull mastiff', + 'Tibetan mastiff', + 'French bulldog', + 'Great Dane', + 'Saint Bernard, St Bernard', + 'Eskimo dog, husky', + 'malamute, malemute, Alaskan malamute', + 'Siberian husky', + 'dalmatian, coach dog, carriage dog', + 'affenpinscher, monkey pinscher, monkey dog', + 'basenji', + 'pug, pug-dog', + 'Leonberg', + 'Newfoundland, Newfoundland dog', + 'Great Pyrenees', + 'Samoyed, Samoyede', + 'Pomeranian', + 'chow, chow chow', + 'keeshond', + 'Brabancon griffon', + 'Pembroke, Pembroke Welsh corgi', + 'Cardigan, Cardigan Welsh corgi', + 'toy poodle', + 'miniature poodle', + 'standard poodle', + 'Mexican hairless', + 'timber wolf, grey wolf, gray wolf, Canis lupus', + 'white wolf, Arctic wolf, Canis lupus tundrarum', + 'red wolf, maned wolf, Canis rufus, Canis niger', + 'coyote, prairie wolf, brush wolf, Canis latrans', + 'dingo, warrigal, warragal, Canis dingo', + 'dhole, Cuon alpinus', + 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', + 'hyena, hyaena', + 'red fox, Vulpes vulpes', + 'kit fox, Vulpes macrotis', + 'Arctic fox, white fox, Alopex lagopus', + 'grey fox, gray fox, Urocyon cinereoargenteus', + 'tabby, tabby cat', + 'tiger cat', + 'Persian cat', + 'Siamese cat, Siamese', + 'Egyptian cat', + 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', # noqa: E501 + 'lynx, catamount', + 'leopard, Panthera pardus', + 'snow leopard, ounce, Panthera uncia', + 'jaguar, panther, Panthera onca, Felis onca', + 'lion, king of beasts, Panthera leo', + 'tiger, Panthera tigris', + 'cheetah, chetah, Acinonyx jubatus', + 'brown bear, bruin, Ursus arctos', + 'American black bear, black bear, Ursus americanus, Euarctos americanus', # noqa: E501 + 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', + 'sloth bear, Melursus ursinus, Ursus ursinus', + 'mongoose', + 'meerkat, mierkat', + 'tiger beetle', + 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', + 'ground beetle, carabid beetle', + 'long-horned beetle, longicorn, longicorn beetle', + 'leaf beetle, chrysomelid', + 'dung beetle', + 'rhinoceros beetle', + 'weevil', + 'fly', + 'bee', + 'ant, emmet, pismire', + 
'grasshopper, hopper', + 'cricket', + 'walking stick, walkingstick, stick insect', + 'cockroach, roach', + 'mantis, mantid', + 'cicada, cicala', + 'leafhopper', + 'lacewing, lacewing fly', + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", # noqa: E501 + 'damselfly', + 'admiral', + 'ringlet, ringlet butterfly', + 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', + 'cabbage butterfly', + 'sulphur butterfly, sulfur butterfly', + 'lycaenid, lycaenid butterfly', + 'starfish, sea star', + 'sea urchin', + 'sea cucumber, holothurian', + 'wood rabbit, cottontail, cottontail rabbit', + 'hare', + 'Angora, Angora rabbit', + 'hamster', + 'porcupine, hedgehog', + 'fox squirrel, eastern fox squirrel, Sciurus niger', + 'marmot', + 'beaver', + 'guinea pig, Cavia cobaya', + 'sorrel', + 'zebra', + 'hog, pig, grunter, squealer, Sus scrofa', + 'wild boar, boar, Sus scrofa', + 'warthog', + 'hippopotamus, hippo, river horse, Hippopotamus amphibius', + 'ox', + 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', + 'bison', + 'ram, tup', + 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', # noqa: E501 + 'ibex, Capra ibex', + 'hartebeest', + 'impala, Aepyceros melampus', + 'gazelle', + 'Arabian camel, dromedary, Camelus dromedarius', + 'llama', + 'weasel', + 'mink', + 'polecat, fitch, foulmart, foumart, Mustela putorius', + 'black-footed ferret, ferret, Mustela nigripes', + 'otter', + 'skunk, polecat, wood pussy', + 'badger', + 'armadillo', + 'three-toed sloth, ai, Bradypus tridactylus', + 'orangutan, orang, orangutang, Pongo pygmaeus', + 'gorilla, Gorilla gorilla', + 'chimpanzee, chimp, Pan troglodytes', + 'gibbon, Hylobates lar', + 'siamang, Hylobates syndactylus, Symphalangus syndactylus', + 'guenon, guenon monkey', + 'patas, hussar monkey, Erythrocebus patas', + 'baboon', + 'macaque', + 'langur', + 'colobus, colobus monkey', + 'proboscis monkey, Nasalis larvatus', + 'marmoset', + 'capuchin, ringtail, Cebus capucinus', + 'howler monkey, howler', + 'titi, titi monkey', + 'spider monkey, Ateles geoffroyi', + 'squirrel monkey, Saimiri sciureus', + 'Madagascar cat, ring-tailed lemur, Lemur catta', + 'indri, indris, Indri indri, Indri brevicaudatus', + 'Indian elephant, Elephas maximus', + 'African elephant, Loxodonta africana', + 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', + 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', + 'barracouta, snoek', + 'eel', + 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', # noqa: E501 + 'rock beauty, Holocanthus tricolor', + 'anemone fish', + 'sturgeon', + 'gar, garfish, garpike, billfish, Lepisosteus osseus', + 'lionfish', + 'puffer, pufferfish, blowfish, globefish', + 'abacus', + 'abaya', + "academic gown, academic robe, judge's robe", + 'accordion, piano accordion, squeeze box', + 'acoustic guitar', + 'aircraft carrier, carrier, flattop, attack aircraft carrier', + 'airliner', + 'airship, dirigible', + 'altar', + 'ambulance', + 'amphibian, amphibious vehicle', + 'analog clock', + 'apiary, bee house', + 'apron', + 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', # noqa: E501 + 'assault rifle, assault gun', + 'backpack, back pack, knapsack, packsack, rucksack, haversack', + 'bakery, bakeshop, bakehouse', + 'balance beam, beam', + 'balloon', + 'ballpoint, ballpoint pen, ballpen, Biro', + 'Band Aid', + 
'banjo', + 'bannister, banister, balustrade, balusters, handrail', + 'barbell', + 'barber chair', + 'barbershop', + 'barn', + 'barometer', + 'barrel, cask', + 'barrow, garden cart, lawn cart, wheelbarrow', + 'baseball', + 'basketball', + 'bassinet', + 'bassoon', + 'bathing cap, swimming cap', + 'bath towel', + 'bathtub, bathing tub, bath, tub', + 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', # noqa: E501 + 'beacon, lighthouse, beacon light, pharos', + 'beaker', + 'bearskin, busby, shako', + 'beer bottle', + 'beer glass', + 'bell cote, bell cot', + 'bib', + 'bicycle-built-for-two, tandem bicycle, tandem', + 'bikini, two-piece', + 'binder, ring-binder', + 'binoculars, field glasses, opera glasses', + 'birdhouse', + 'boathouse', + 'bobsled, bobsleigh, bob', + 'bolo tie, bolo, bola tie, bola', + 'bonnet, poke bonnet', + 'bookcase', + 'bookshop, bookstore, bookstall', + 'bottlecap', + 'bow', + 'bow tie, bow-tie, bowtie', + 'brass, memorial tablet, plaque', + 'brassiere, bra, bandeau', + 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', + 'breastplate, aegis, egis', + 'broom', + 'bucket, pail', + 'buckle', + 'bulletproof vest', + 'bullet train, bullet', + 'butcher shop, meat market', + 'cab, hack, taxi, taxicab', + 'caldron, cauldron', + 'candle, taper, wax light', + 'cannon', + 'canoe', + 'can opener, tin opener', + 'cardigan', + 'car mirror', + 'carousel, carrousel, merry-go-round, roundabout, whirligig', + "carpenter's kit, tool kit", + 'carton', + 'car wheel', + 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', # noqa: E501 + 'cassette', + 'cassette player', + 'castle', + 'catamaran', + 'CD player', + 'cello, violoncello', + 'cellular telephone, cellular phone, cellphone, cell, mobile phone', + 'chain', + 'chainlink fence', + 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', # noqa: E501 + 'chain saw, chainsaw', + 'chest', + 'chiffonier, commode', + 'chime, bell, gong', + 'china cabinet, china closet', + 'Christmas stocking', + 'church, church building', + 'cinema, movie theater, movie theatre, movie house, picture palace', + 'cleaver, meat cleaver, chopper', + 'cliff dwelling', + 'cloak', + 'clog, geta, patten, sabot', + 'cocktail shaker', + 'coffee mug', + 'coffeepot', + 'coil, spiral, volute, whorl, helix', + 'combination lock', + 'computer keyboard, keypad', + 'confectionery, confectionary, candy store', + 'container ship, containership, container vessel', + 'convertible', + 'corkscrew, bottle screw', + 'cornet, horn, trumpet, trump', + 'cowboy boot', + 'cowboy hat, ten-gallon hat', + 'cradle', + 'crane', + 'crash helmet', + 'crate', + 'crib, cot', + 'Crock Pot', + 'croquet ball', + 'crutch', + 'cuirass', + 'dam, dike, dyke', + 'desk', + 'desktop computer', + 'dial telephone, dial phone', + 'diaper, nappy, napkin', + 'digital clock', + 'digital watch', + 'dining table, board', + 'dishrag, dishcloth', + 'dishwasher, dish washer, dishwashing machine', + 'disk brake, disc brake', + 'dock, dockage, docking facility', + 'dogsled, dog sled, dog sleigh', + 'dome', + 'doormat, welcome mat', + 'drilling platform, offshore rig', + 'drum, membranophone, tympan', + 'drumstick', + 'dumbbell', + 'Dutch oven', + 'electric fan, blower', + 'electric guitar', + 'electric locomotive', + 'entertainment center', + 'envelope', + 'espresso maker', + 'face powder', + 'feather boa, boa', + 'file, file cabinet, filing cabinet', + 'fireboat', + 'fire engine, 
fire truck', + 'fire screen, fireguard', + 'flagpole, flagstaff', + 'flute, transverse flute', + 'folding chair', + 'football helmet', + 'forklift', + 'fountain', + 'fountain pen', + 'four-poster', + 'freight car', + 'French horn, horn', + 'frying pan, frypan, skillet', + 'fur coat', + 'garbage truck, dustcart', + 'gasmask, respirator, gas helmet', + 'gas pump, gasoline pump, petrol pump, island dispenser', + 'goblet', + 'go-kart', + 'golf ball', + 'golfcart, golf cart', + 'gondola', + 'gong, tam-tam', + 'gown', + 'grand piano, grand', + 'greenhouse, nursery, glasshouse', + 'grille, radiator grille', + 'grocery store, grocery, food market, market', + 'guillotine', + 'hair slide', + 'hair spray', + 'half track', + 'hammer', + 'hamper', + 'hand blower, blow dryer, blow drier, hair dryer, hair drier', + 'hand-held computer, hand-held microcomputer', + 'handkerchief, hankie, hanky, hankey', + 'hard disc, hard disk, fixed disk', + 'harmonica, mouth organ, harp, mouth harp', + 'harp', + 'harvester, reaper', + 'hatchet', + 'holster', + 'home theater, home theatre', + 'honeycomb', + 'hook, claw', + 'hoopskirt, crinoline', + 'horizontal bar, high bar', + 'horse cart, horse-cart', + 'hourglass', + 'iPod', + 'iron, smoothing iron', + "jack-o'-lantern", + 'jean, blue jean, denim', + 'jeep, landrover', + 'jersey, T-shirt, tee shirt', + 'jigsaw puzzle', + 'jinrikisha, ricksha, rickshaw', + 'joystick', + 'kimono', + 'knee pad', + 'knot', + 'lab coat, laboratory coat', + 'ladle', + 'lampshade, lamp shade', + 'laptop, laptop computer', + 'lawn mower, mower', + 'lens cap, lens cover', + 'letter opener, paper knife, paperknife', + 'library', + 'lifeboat', + 'lighter, light, igniter, ignitor', + 'limousine, limo', + 'liner, ocean liner', + 'lipstick, lip rouge', + 'Loafer', + 'lotion', + 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', # noqa: E501 + "loupe, jeweler's loupe", + 'lumbermill, sawmill', + 'magnetic compass', + 'mailbag, postbag', + 'mailbox, letter box', + 'maillot', + 'maillot, tank suit', + 'manhole cover', + 'maraca', + 'marimba, xylophone', + 'mask', + 'matchstick', + 'maypole', + 'maze, labyrinth', + 'measuring cup', + 'medicine chest, medicine cabinet', + 'megalith, megalithic structure', + 'microphone, mike', + 'microwave, microwave oven', + 'military uniform', + 'milk can', + 'minibus', + 'miniskirt, mini', + 'minivan', + 'missile', + 'mitten', + 'mixing bowl', + 'mobile home, manufactured home', + 'Model T', + 'modem', + 'monastery', + 'monitor', + 'moped', + 'mortar', + 'mortarboard', + 'mosque', + 'mosquito net', + 'motor scooter, scooter', + 'mountain bike, all-terrain bike, off-roader', + 'mountain tent', + 'mouse, computer mouse', + 'mousetrap', + 'moving van', + 'muzzle', + 'nail', + 'neck brace', + 'necklace', + 'nipple', + 'notebook, notebook computer', + 'obelisk', + 'oboe, hautboy, hautbois', + 'ocarina, sweet potato', + 'odometer, hodometer, mileometer, milometer', + 'oil filter', + 'organ, pipe organ', + 'oscilloscope, scope, cathode-ray oscilloscope, CRO', + 'overskirt', + 'oxcart', + 'oxygen mask', + 'packet', + 'paddle, boat paddle', + 'paddlewheel, paddle wheel', + 'padlock', + 'paintbrush', + "pajama, pyjama, pj's, jammies", + 'palace', + 'panpipe, pandean pipe, syrinx', + 'paper towel', + 'parachute, chute', + 'parallel bars, bars', + 'park bench', + 'parking meter', + 'passenger car, coach, carriage', + 'patio, terrace', + 'pay-phone, pay-station', + 'pedestal, plinth, footstall', + 'pencil box, pencil case', + 'pencil sharpener', + 'perfume, 
essence', + 'Petri dish', + 'photocopier', + 'pick, plectrum, plectron', + 'pickelhaube', + 'picket fence, paling', + 'pickup, pickup truck', + 'pier', + 'piggy bank, penny bank', + 'pill bottle', + 'pillow', + 'ping-pong ball', + 'pinwheel', + 'pirate, pirate ship', + 'pitcher, ewer', + "plane, carpenter's plane, woodworking plane", + 'planetarium', + 'plastic bag', + 'plate rack', + 'plow, plough', + "plunger, plumber's helper", + 'Polaroid camera, Polaroid Land camera', + 'pole', + 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', # noqa: E501 + 'poncho', + 'pool table, billiard table, snooker table', + 'pop bottle, soda bottle', + 'pot, flowerpot', + "potter's wheel", + 'power drill', + 'prayer rug, prayer mat', + 'printer', + 'prison, prison house', + 'projectile, missile', + 'projector', + 'puck, hockey puck', + 'punching bag, punch bag, punching ball, punchball', + 'purse', + 'quill, quill pen', + 'quilt, comforter, comfort, puff', + 'racer, race car, racing car', + 'racket, racquet', + 'radiator', + 'radio, wireless', + 'radio telescope, radio reflector', + 'rain barrel', + 'recreational vehicle, RV, R.V.', + 'reel', + 'reflex camera', + 'refrigerator, icebox', + 'remote control, remote', + 'restaurant, eating house, eating place, eatery', + 'revolver, six-gun, six-shooter', + 'rifle', + 'rocking chair, rocker', + 'rotisserie', + 'rubber eraser, rubber, pencil eraser', + 'rugby ball', + 'rule, ruler', + 'running shoe', + 'safe', + 'safety pin', + 'saltshaker, salt shaker', + 'sandal', + 'sarong', + 'sax, saxophone', + 'scabbard', + 'scale, weighing machine', + 'school bus', + 'schooner', + 'scoreboard', + 'screen, CRT screen', + 'screw', + 'screwdriver', + 'seat belt, seatbelt', + 'sewing machine', + 'shield, buckler', + 'shoe shop, shoe-shop, shoe store', + 'shoji', + 'shopping basket', + 'shopping cart', + 'shovel', + 'shower cap', + 'shower curtain', + 'ski', + 'ski mask', + 'sleeping bag', + 'slide rule, slipstick', + 'sliding door', + 'slot, one-armed bandit', + 'snorkel', + 'snowmobile', + 'snowplow, snowplough', + 'soap dispenser', + 'soccer ball', + 'sock', + 'solar dish, solar collector, solar furnace', + 'sombrero', + 'soup bowl', + 'space bar', + 'space heater', + 'space shuttle', + 'spatula', + 'speedboat', + "spider web, spider's web", + 'spindle', + 'sports car, sport car', + 'spotlight, spot', + 'stage', + 'steam locomotive', + 'steel arch bridge', + 'steel drum', + 'stethoscope', + 'stole', + 'stone wall', + 'stopwatch, stop watch', + 'stove', + 'strainer', + 'streetcar, tram, tramcar, trolley, trolley car', + 'stretcher', + 'studio couch, day bed', + 'stupa, tope', + 'submarine, pigboat, sub, U-boat', + 'suit, suit of clothes', + 'sundial', + 'sunglass', + 'sunglasses, dark glasses, shades', + 'sunscreen, sunblock, sun blocker', + 'suspension bridge', + 'swab, swob, mop', + 'sweatshirt', + 'swimming trunks, bathing trunks', + 'swing', + 'switch, electric switch, electrical switch', + 'syringe', + 'table lamp', + 'tank, army tank, armored combat vehicle, armoured combat vehicle', + 'tape player', + 'teapot', + 'teddy, teddy bear', + 'television, television system', + 'tennis ball', + 'thatch, thatched roof', + 'theater curtain, theatre curtain', + 'thimble', + 'thresher, thrasher, threshing machine', + 'throne', + 'tile roof', + 'toaster', + 'tobacco shop, tobacconist shop, tobacconist', + 'toilet seat', + 'torch', + 'totem pole', + 'tow truck, tow car, wrecker', + 'toyshop', + 'tractor', + 'trailer truck, tractor trailer, trucking rig, rig, 
articulated lorry, semi', # noqa: E501 + 'tray', + 'trench coat', + 'tricycle, trike, velocipede', + 'trimaran', + 'tripod', + 'triumphal arch', + 'trolleybus, trolley coach, trackless trolley', + 'trombone', + 'tub, vat', + 'turnstile', + 'typewriter keyboard', + 'umbrella', + 'unicycle, monocycle', + 'upright, upright piano', + 'vacuum, vacuum cleaner', + 'vase', + 'vault', + 'velvet', + 'vending machine', + 'vestment', + 'viaduct', + 'violin, fiddle', + 'volleyball', + 'waffle iron', + 'wall clock', + 'wallet, billfold, notecase, pocketbook', + 'wardrobe, closet, press', + 'warplane, military plane', + 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', + 'washer, automatic washer, washing machine', + 'water bottle', + 'water jug', + 'water tower', + 'whiskey jug', + 'whistle', + 'wig', + 'window screen', + 'window shade', + 'Windsor tie', + 'wine bottle', + 'wing', + 'wok', + 'wooden spoon', + 'wool, woolen, woollen', + 'worm fence, snake fence, snake-rail fence, Virginia fence', + 'wreck', + 'yawl', + 'yurt', + 'web site, website, internet site, site', + 'comic book', + 'crossword puzzle, crossword', + 'street sign', + 'traffic light, traffic signal, stoplight', + 'book jacket, dust cover, dust jacket, dust wrapper', + 'menu', + 'plate', + 'guacamole', + 'consomme', + 'hot pot, hotpot', + 'trifle', + 'ice cream, icecream', + 'ice lolly, lolly, lollipop, popsicle', + 'French loaf', + 'bagel, beigel', + 'pretzel', + 'cheeseburger', + 'hotdog, hot dog, red hot', + 'mashed potato', + 'head cabbage', + 'broccoli', + 'cauliflower', + 'zucchini, courgette', + 'spaghetti squash', + 'acorn squash', + 'butternut squash', + 'cucumber, cuke', + 'artichoke, globe artichoke', + 'bell pepper', + 'cardoon', + 'mushroom', + 'Granny Smith', + 'strawberry', + 'orange', + 'lemon', + 'fig', + 'pineapple, ananas', + 'banana', + 'jackfruit, jak, jack', + 'custard apple', + 'pomegranate', + 'hay', + 'carbonara', + 'chocolate sauce, chocolate syrup', + 'dough', + 'meat loaf, meatloaf', + 'pizza, pizza pie', + 'potpie', + 'burrito', + 'red wine', + 'espresso', + 'cup', + 'eggnog', + 'alp', + 'bubble', + 'cliff, drop, drop-off', + 'coral reef', + 'geyser', + 'lakeside, lakeshore', + 'promontory, headland, head, foreland', + 'sandbar, sand bar', + 'seashore, coast, seacoast, sea-coast', + 'valley, vale', + 'volcano', + 'ballplayer, baseball player', + 'groom, bridegroom', + 'scuba diver', + 'rapeseed', + 'daisy', + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", # noqa: E501 + 'corn', + 'acorn', + 'hip, rose hip, rosehip', + 'buckeye, horse chestnut, conker', + 'coral fungus', + 'agaric', + 'gyromitra', + 'stinkhorn, carrion fungus', + 'earthstar', + 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', # noqa: E501 + 'bolete', + 'ear, spike, capitulum', + 'toilet tissue, toilet paper, bathroom tissue') + +CIFAR10_CATEGORIES = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', + 'frog', 'horse', 'ship', 'truck') + +CIFAR100_CATEGORIES = ( + 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', + 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', + 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', + 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', + 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', + 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', + 'lizard', 'lobster', 'man', 'maple_tree', 
'motorcycle', 'mountain', + 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', + 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', + 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', + 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', + 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', + 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', + 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', + 'woman', 'worm') + +MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', + '5 - five', '6 - six', '7 - seven', '8 - eight', + '9 - nine') + +FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', + 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', + 'Ankle boot') + +PLACES205_CATEGORIES = ( + 'abbey', 'airport_terminal', 'alley', 'amphitheater', 'amusement_park', + 'aquarium', 'aqueduct', 'arch', 'art_gallery', 'art_studio', + 'assembly_line', 'attic', 'auditorium', 'apartment_building/outdoor', + 'badlands', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', + 'baseball_field', 'basement', 'basilica', 'bayou', 'beauty_salon', + 'bedroom', 'boardwalk', 'boat_deck', 'bookstore', 'botanical_garden', + 'bowling_alley', 'boxing_ring', 'bridge', 'building_facade', + 'bus_interior', 'butchers_shop', 'butte', 'bakery/shop', 'cafeteria', + 'campsite', 'candy_store', 'canyon', 'castle', 'cemetery', 'chalet', + 'classroom', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', + 'conference_center', 'conference_room', 'construction_site', 'corn_field', + 'corridor', 'cottage_garden', 'courthouse', 'courtyard', 'creek', + 'crevasse', 'crosswalk', 'cathedral/outdoor', 'church/outdoor', 'dam', + 'dining_room', 'dock', 'dorm_room', 'driveway', 'desert/sand', + 'desert/vegetation', 'dinette/home', 'doorway/outdoor', 'engine_room', + 'excavation', 'fairway', 'fire_escape', 'fire_station', 'food_court', + 'forest_path', 'forest_road', 'formal_garden', 'fountain', + 'field/cultivated', 'field/wild', 'galley', 'game_room', 'garbage_dump', + 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'herb_garden', + 'highway', 'home_office', 'hospital', 'hospital_room', 'hot_spring', + 'hotel_room', 'hotel/outdoor', 'ice_cream_parlor', 'iceberg', 'igloo', + 'islet', 'ice_skating_rink/outdoor', 'inn/outdoor', 'jail_cell', 'kasbah', + 'kindergarden_classroom', 'kitchen', 'kitchenette', 'laundromat', + 'lighthouse', 'living_room', 'lobby', 'locker_room', 'mansion', 'marsh', + 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain', + 'mountain_snowy', 'music_studio', 'market/outdoor', 'monastery/outdoor', + 'museum/indoor', 'nursery', 'ocean', 'office', 'office_building', + 'orchard', 'pagoda', 'palace', 'pantry', 'parking_lot', 'parlor', + 'pasture', 'patio', 'pavilion', 'phone_booth', 'picnic_area', 'playground', + 'plaza', 'pond', 'pulpit', 'racecourse', 'raft', 'railroad_track', + 'rainforest', 'reception', 'residential_neighborhood', 'restaurant', + 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'river', + 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sandbar', 'schoolhouse', + 'sea_cliff', 'shed', 'shoe_shop', 'shopfront', 'shower', 'ski_resort', + 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'staircase', + 'supermarket', 'swamp', 'stadium/baseball', 'stadium/football', + 'stage/indoor', 'subway_station/platform', 'swimming_pool/outdoor', + 'television_studio', 'topiary_garden', 'tower', 'train_railway', 
+ 'tree_farm', 'trench', 'temple/east_asia', 'temple/south_asia', + 'track/outdoor', 'train_station/platform', 'underwater/coral_reef', + 'valley', 'vegetable_garden', 'veranda', 'viaduct', 'volcano', + 'waiting_room', 'water_tower', 'watering_hole', 'wheat_field', 'wind_farm', + 'windmill', 'yard') + +OxfordIIITPet_CATEGORIES = ( + 'Abyssinian', 'american_bulldog', 'american_pit_bull_terrier', + 'basset_hound', 'beagle', 'Bengal', 'Birman', 'Bombay', 'boxer', + 'British_Shorthair', 'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel', + 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', + 'japanese_chin', 'keeshond', 'leonberger', 'Maine_Coon', + 'miniature_pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug', + 'Ragdoll', 'Russian_Blue', 'saint_bernard', 'samoyed', 'scottish_terrier', + 'shiba_inu', 'Siamese', 'Sphynx', 'staffordshire_bull_terrier', + 'wheaten_terrier', 'yorkshire_terrier') + +DTD_CATEGORIES = ('banded', 'blotchy', 'braided', 'bubbly', 'bumpy', + 'chequered', 'cobwebbed', 'cracked', 'crosshatched', + 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', + 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', + 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', + 'matted', 'meshed', 'paisley', 'perforated', 'pitted', + 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', + 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', + 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', + 'wrinkled', 'zigzagged') + +FGVCAIRCRAFT_CATEGORIES = ( + '707-320', '727-200', '737-200', '737-300', '737-400', '737-500', + '737-600', '737-700', '737-800', '737-900', '747-100', '747-200', + '747-300', '747-400', '757-200', '757-300', '767-200', '767-300', + '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320', + 'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500', + 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200', + 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47', + 'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525', + 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30', + 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300', 'DR-400', + 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145', + 'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A/B', 'F/A-18', + 'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70', + 'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76', + 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200', + 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134', + 'Tu-154', 'Yak-42') + +STANFORDCARS_CATEGORIES = ( + 'AM General Hummer SUV 2000', 'Acura RL Sedan 2012', 'Acura TL Sedan 2012', + 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012', + 'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012', + 'Aston Martin V8 Vantage Convertible 2012', + 'Aston Martin V8 Vantage Coupe 2012', + 'Aston Martin Virage Convertible 2012', 'Aston Martin Virage Coupe 2012', + 'Audi RS 4 Convertible 2008', 'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012', + 'Audi R8 Coupe 2012', 'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994', + 'Audi 100 Wagon 1994', 'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011', + 'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012', + 'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012', + 'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012', + 'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 
2012', + 'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007', + 'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012', + 'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012', + 'BMW Z4 Convertible 2012', + 'Bentley Continental Supersports Conv. Convertible 2012', + 'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011', + 'Bentley Continental GT Coupe 2012', 'Bentley Continental GT Coupe 2007', + 'Bentley Continental Flying Spur Sedan 2007', + 'Bugatti Veyron 16.4 Convertible 2009', 'Bugatti Veyron 16.4 Coupe 2009', + 'Buick Regal GS 2012', 'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012', + 'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012', + 'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007', + 'Chevrolet Silverado 1500 Hybrid Crew Cab 2012', + 'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012', + 'Chevrolet Corvette Ron Fellows Edition Z06 2007', + 'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012', + 'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007', + 'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012', + 'Chevrolet Express Cargo Van 2007', 'Chevrolet Avalanche Crew Cab 2012', + 'Chevrolet Cobalt SS 2010', 'Chevrolet Malibu Hybrid Sedan 2010', + 'Chevrolet TrailBlazer SS 2009', + 'Chevrolet Silverado 2500HD Regular Cab 2012', + 'Chevrolet Silverado 1500 Classic Extended Cab 2007', + 'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007', + 'Chevrolet Malibu Sedan 2007', + 'Chevrolet Silverado 1500 Extended Cab 2012', + 'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009', + 'Chrysler Sebring Convertible 2010', + 'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010', + 'Chrysler Crossfire Convertible 2008', + 'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002', + 'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007', + 'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010', + 'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009', + 'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010', + 'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008', + 'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012', + 'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012', + 'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998', + 'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012', + 'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012', + 'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012', + 'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012', + 'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007', + 'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012', + 'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006', + 'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007', + 'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012', + 'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012', 'GMC Savana Van 2012', + 'GMC Yukon Hybrid SUV 2012', 'GMC Acadia SUV 2012', + 'GMC Canyon Extended Cab 2012', 'Geo Metro Convertible 1993', + 'HUMMER H3T Crew Cab 2010', 'HUMMER H2 SUT Crew Cab 2009', + 'Honda Odyssey Minivan 2012', 'Honda Odyssey Minivan 2007', + 'Honda Accord Coupe 2012', 'Honda Accord Sedan 2012', + 'Hyundai Veloster Hatchback 2012', 'Hyundai Santa Fe SUV 2012', + 'Hyundai Tucson SUV 2012', 'Hyundai Veracruz SUV 2012', + 'Hyundai Sonata Hybrid Sedan 2012', 'Hyundai Elantra Sedan 2007', + 'Hyundai Accent Sedan 2012', 'Hyundai Genesis Sedan 2012', + 'Hyundai Sonata Sedan 
2012', 'Hyundai Elantra Touring Hatchback 2012', + 'Hyundai Azera Sedan 2012', 'Infiniti G Coupe IPL 2012', + 'Infiniti QX56 SUV 2011', 'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012', + 'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012', 'Jeep Liberty SUV 2012', + 'Jeep Grand Cherokee SUV 2012', 'Jeep Compass SUV 2012', + 'Lamborghini Reventon Coupe 2008', 'Lamborghini Aventador Coupe 2012', + 'Lamborghini Gallardo LP 570-4 Superleggera 2012', + 'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012', + 'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011', + 'MINI Cooper Roadster Convertible 2012', + 'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011', + 'McLaren MP4-12C Coupe 2012', 'Mercedes-Benz 300-Class Convertible 1993', + 'Mercedes-Benz C-Class Sedan 2012', 'Mercedes-Benz SL-Class Coupe 2009', + 'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012', + 'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012', + 'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012', + 'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998', + 'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012', + 'Ram C/V Cargo Van Minivan 2012', + 'Rolls-Royce Phantom Drophead Coupe Convertible 2012', + 'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012', + 'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009', + 'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007', + 'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012', + 'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012', + 'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012', + 'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012', + 'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991', + 'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012', + 'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007', + 'smart fortwo Convertible 2012') + +SUN397_CATEGORIES = ( + 'abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', + 'amusement_arcade', 'amusement_park', 'anechoic_chamber', + 'apartment_building_outdoor', 'apse_indoor', 'aquarium', 'aqueduct', + 'arch', 'archive', 'arrival_gate_outdoor', 'art_gallery', 'art_school', + 'art_studio', 'assembly_line', 'athletic_field_outdoor', 'atrium_public', + 'attic', 'auditorium', 'auto_factory', 'badlands', + 'badminton_court_indoor', 'baggage_claim', 'bakery_shop', + 'balcony_exterior', 'balcony_interior', 'ball_pit', 'ballroom', + 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', + 'baseball_field', 'basement', 'basilica', 'basketball_court_outdoor', + 'bathroom', 'batters_box', 'bayou', 'bazaar_indoor', 'bazaar_outdoor', + 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', + 'bistro_indoor', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', + 'booth_indoor', 'botanical_garden', 'bow_window_indoor', + 'bow_window_outdoor', 'bowling_alley', 'boxing_ring', 'brewery_indoor', + 'bridge', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', + 'butchers_shop', 'butte', 'cabin_outdoor', 'cafeteria', 'campsite', + 'campus', 'canal_natural', 'canal_urban', 'candy_store', 'canyon', + 'car_interior_backseat', 'car_interior_frontseat', 'carrousel', + 'casino_indoor', 'castle', 'catacomb', 'cathedral_indoor', + 'cathedral_outdoor', 'cavern_indoor', 'cemetery', 'chalet', + 'cheese_factory', 'chemistry_lab', 'chicken_coop_indoor', + 'chicken_coop_outdoor', 'childs_room', 'church_indoor', 'church_outdoor', + 'classroom', 'clean_room', 'cliff', 'cloister_indoor', 'closet', + 'clothing_store', 
'coast', 'cockpit', 'coffee_shop', 'computer_room', + 'conference_center', 'conference_room', 'construction_site', + 'control_room', 'control_tower_outdoor', 'corn_field', 'corral', + 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', + 'covered_bridge_exterior', 'creek', 'crevasse', 'crosswalk', + 'cubicle_office', 'dam', 'delicatessen', 'dentists_office', 'desert_sand', + 'desert_vegetation', 'diner_indoor', 'diner_outdoor', 'dinette_home', + 'dinette_vehicle', 'dining_car', 'dining_room', 'discotheque', 'dock', + 'doorway_outdoor', 'dorm_room', 'driveway', 'driving_range_outdoor', + 'drugstore', 'electrical_substation', 'elevator_door', 'elevator_interior', + 'elevator_shaft', 'engine_room', 'escalator_indoor', 'excavation', + 'factory_indoor', 'fairway', 'fastfood_restaurant', 'field_cultivated', + 'field_wild', 'fire_escape', 'fire_station', 'firing_range_indoor', + 'fishpond', 'florist_shop_indoor', 'food_court', 'forest_broadleaf', + 'forest_needleleaf', 'forest_path', 'forest_road', 'formal_garden', + 'fountain', 'galley', 'game_room', 'garage_indoor', 'garbage_dump', + 'gas_station', 'gazebo_exterior', 'general_store_indoor', + 'general_store_outdoor', 'gift_shop', 'golf_course', 'greenhouse_indoor', + 'greenhouse_outdoor', 'gymnasium_indoor', 'hangar_indoor', + 'hangar_outdoor', 'harbor', 'hayfield', 'heliport', 'herb_garden', + 'highway', 'hill', 'home_office', 'hospital', 'hospital_room', + 'hot_spring', 'hot_tub_outdoor', 'hotel_outdoor', 'hotel_room', 'house', + 'hunting_lodge_outdoor', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', + 'ice_skating_rink_indoor', 'ice_skating_rink_outdoor', 'iceberg', 'igloo', + 'industrial_area', 'inn_outdoor', 'islet', 'jacuzzi_indoor', 'jail_indoor', + 'jail_cell', 'jewelry_shop', 'kasbah', 'kennel_indoor', 'kennel_outdoor', + 'kindergarden_classroom', 'kitchen', 'kitchenette', 'labyrinth_outdoor', + 'lake_natural', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', + 'library_indoor', 'library_outdoor', 'lido_deck_outdoor', 'lift_bridge', + 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', + 'locker_room', 'mansion', 'manufactured_home', 'market_indoor', + 'market_outdoor', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', + 'moat_water', 'monastery_outdoor', 'mosque_indoor', 'mosque_outdoor', + 'motel', 'mountain', 'mountain_snowy', 'movie_theater_indoor', + 'museum_indoor', 'music_store', 'music_studio', + 'nuclear_power_plant_outdoor', 'nursery', 'oast_house', + 'observatory_outdoor', 'ocean', 'office', 'office_building', + 'oil_refinery_outdoor', 'oilrig', 'operating_room', 'orchard', + 'outhouse_outdoor', 'pagoda', 'palace', 'pantry', 'park', + 'parking_garage_indoor', 'parking_garage_outdoor', 'parking_lot', 'parlor', + 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', + 'physics_laboratory', 'picnic_area', 'pilothouse_indoor', + 'planetarium_outdoor', 'playground', 'playroom', 'plaza', 'podium_indoor', + 'podium_outdoor', 'pond', 'poolroom_establishment', 'poolroom_home', + 'power_plant_outdoor', 'promenade_deck', 'pub_indoor', 'pulpit', + 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', + 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', + 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', + 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', + 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', + 'shed', 'shoe_shop', 'shopfront', 'shopping_mall_indoor', 'shower', + 
'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', + 'slum', 'snowfield', 'squash_court', 'stable', 'stadium_baseball', + 'stadium_football', 'stage_indoor', 'staircase', 'street', + 'subway_interior', 'subway_station_platform', 'supermarket', 'sushi_bar', + 'swamp', 'swimming_pool_indoor', 'swimming_pool_outdoor', + 'synagogue_indoor', 'synagogue_outdoor', 'television_studio', + 'temple_east_asia', 'temple_south_asia', 'tennis_court_indoor', + 'tennis_court_outdoor', 'tent_outdoor', 'theater_indoor_procenium', + 'theater_indoor_seats', 'thriftshop', 'throne_room', 'ticket_booth', + 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'track_outdoor', + 'train_railway', 'train_station_platform', 'tree_farm', 'tree_house', + 'trench', 'underwater_coral_reef', 'utility_room', 'valley', + 'van_interior', 'vegetable_garden', 'veranda', 'veterinarians_office', + 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', + 'volleyball_court_indoor', 'volleyball_court_outdoor', 'waiting_room', + 'warehouse_indoor', 'water_tower', 'waterfall_block', 'waterfall_fan', + 'waterfall_plunge', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', + 'wind_farm', 'windmill', 'wine_cellar_barrel_storage', + 'wine_cellar_bottle_storage', 'wrestling_ring_indoor', 'yard', + 'youth_hostel') + +CALTECH101_CATEGORIES = ( + 'BACKGROUND_Google', 'Faces', 'Faces_easy', 'Leopards', 'Motorbikes', + 'accordion', 'airplanes', 'anchor', 'ant', 'barrel', 'bass', 'beaver', + 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', + 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', + 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', + 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', + 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', + 'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', + 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', + 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', + 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', + 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', + 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', + 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', + 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', + 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', + 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', + 'wrench', 'yin_yang') + +FOOD101_CATEGORIES = ( + 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', + 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', + 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', + 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', + 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', + 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', + 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', + 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', + 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', + 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', + 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', + 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', + 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', + 'ice_cream', 'lasagna', 
'lobster_bisque', 'lobster_roll_sandwich', + 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', + 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', + 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', + 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', + 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', + 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', + 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', + 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles') + +CIFAR100_CATEGORIES_CN = ( + '苹果', '水族馆鱼', '婴儿', '熊', '河狸', '床', '蜜蜂', '甲虫', '自行车', '瓶子', '碗', '小男孩', + '桥', '公共汽车', '蝴蝶', '骆驼', '易拉罐', '城堡', '毛毛虫', '牛', '椅子', '猩猩', '钟', '白云', + '蟑螂', '沙发', '螃蟹', '鳄鱼', '杯子', '恐龙', '海豚', '大象', '比目鱼', '森林', '狐狸', '小女孩', + '仓鼠', '屋子', '袋鼠', '键盘', '台灯', '割草机', '猎豹', '狮子', '蜥蜴', '龙虾', '男人', '枫树', + '摩托车', '山', '老鼠', '蘑菇', '橡树', '橙子橘子', '兰花', '水獭', '棕榈树', '梨', '皮卡车', '松树', + '田野', '盘子', '罂粟', '豪猪', '负鼠', '兔子', '浣熊', '鳐鱼', '公路', '火箭', '玫瑰', '大海', + '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒', + '桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼', + '柳树', '狼', '女人', '蠕虫') + +IMAGENET_SIMPLE_CATEGORIES = ( + 'tench', 'goldfish', 'great white shark', 'tiger shark', + 'hammerhead shark', 'electric ray', 'stingray', 'rooster', 'hen', + 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', + 'indigo bunting', 'American robin', 'bulbul', 'jay', 'magpie', 'chickadee', + 'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture', + 'great grey owl', 'fire salamander', 'smooth newt', 'newt', + 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog', + 'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', + 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'green iguana', + 'Carolina anole', 'desert grassland whiptail lizard', 'agama', + 'frilled-necked lizard', 'alligator lizard', 'Gila monster', + 'European green lizard', 'chameleon', 'Komodo dragon', 'Nile crocodile', + 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', + 'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', + 'garter snake', 'water snake', 'vine snake', 'night snake', + 'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba', + 'sea snake', 'Saharan horned viper', 'eastern diamondback rattlesnake', + 'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion', + 'yellow garden spider', 'barn spider', 'European garden spider', + 'southern black widow', 'tarantula', 'wolf spider', 'tick', 'centipede', + 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl', + 'quail', 'partridge', 'african grey parrot', 'macaw', + 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', 'hornbill', + 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', + 'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala', + 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm', + 'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', + 'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab', + 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', + 'hermit crab', 'isopod', 'white stork', 'black stork', 'spoonbill', + 'flamingo', 'little blue heron', 'great egret', 'bittern bird', + 'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard', + 'ruddy turnstone', 'dunlin', 'common redshank', 
'dowitcher', + 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale', + 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', + 'Maltese', 'Pekingese', 'Shih Tzu', 'King Charles Spaniel', 'Papillon', + 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', + 'Beagle', 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', + 'Treeing Walker Coonhound', 'English foxhound', 'Redbone Coonhound', + 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', + 'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki', + 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier', + 'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', + 'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier', + 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', + 'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier', + 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier', + 'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', + 'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier', + 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', + 'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever', + 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever', + 'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', + 'English Setter', 'Irish Setter', 'Gordon Setter', 'Brittany dog', + 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel', + 'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', + 'Schipperke', 'Groenendael dog', 'Malinois', 'Briard', 'Australian Kelpie', + 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', + 'Border Collie', 'Bouvier des Flandres dog', 'Rottweiler', + 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher', + 'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', + 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff', + 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. 
Bernard', 'husky', + 'Alaskan Malamute', 'Siberian Husky', 'Dalmatian', 'Affenpinscher', + 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog', 'Great Pyrenees dog', + 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon', + 'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle', + 'Miniature Poodle', 'Standard Poodle', + 'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', + 'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote', 'dingo', + 'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', + 'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat', + 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', + 'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear', + 'polar bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle', + 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', + 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant', + 'grasshopper', 'cricket insect', 'stick insect', 'cockroach', + 'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly', + 'damselfly', 'red admiral butterfly', 'ringlet butterfly', + 'monarch butterfly', 'small white butterfly', 'sulphur butterfly', + 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', + 'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', + 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel horse', + 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', + 'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep', + 'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle', + 'arabian camel', 'llama', 'weasel', 'mink', 'European polecat', + 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', + 'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon', + 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur', + 'black-and-white colobus', 'proboscis monkey', 'marmoset', + 'white-headed capuchin', 'howler monkey', 'titi monkey', + "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', + 'indri', 'Asian elephant', 'African bush elephant', 'red panda', + 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish', + 'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', + 'abaya', 'academic gown', 'accordion', 'acoustic guitar', + 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', + 'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can', + 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon', + 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell', + 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow', + 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', + 'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', + 'military hat (bearskin or shako)', 'beer bottle', 'beer glass', + 'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', + 'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', 'bolo tie', + 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow', + 'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate', + 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', + 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', + 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit', + 'cardboard box / carton', 'car wheel', 'automated teller machine', + 'cassette', 'cassette player', 'castle', 
'catamaran', 'CD player', 'cello', + 'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw', + 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet', + 'Christmas stocking', 'church', 'movie theater', 'cleaver', + 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug', + 'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard', + 'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet', + 'cowboy boot', 'cowboy hat', 'cradle', 'construction crane', + 'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', + 'crutch', 'cuirass', 'dam', 'desk', 'desktop computer', + 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', + 'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock', + 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick', + 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', + 'electric locomotive', 'entertainment center', 'envelope', + 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', + 'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute', + 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen', + 'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', + 'garbage truck', 'gas mask or respirator', 'gas pump', 'goblet', 'go-kart', + 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano', + 'greenhouse', 'radiator grille', 'grocery store', 'guillotine', + 'hair clip', 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer', + 'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica', + 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater', + 'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar', + 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', + 'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw', + 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', + 'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library', + 'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick', + 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass', + 'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights', + 'one-piece bathing suit', 'manhole cover', 'maraca', 'marimba', 'mask', + 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet', + 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can', + 'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', + 'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped', + 'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa', + 'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van', + 'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', + 'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer', + 'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart', + 'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel', + 'padlock', 'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel', + 'parachute', 'parallel bars', 'park bench', 'parking meter', + 'railroad car', 'patio', 'payphone', 'pedestal', 'pencil case', + 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum', + 'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', + 'pill bottle', 'pillow', 'ping-pong ball', 'pinwheel', 'pirate ship', + 'drink pitcher', 'block plane', 'planetarium', 'plastic 
bag', 'plate rack', + 'farm plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', + 'pool table', 'soda bottle', 'plant pot', "potter's wheel", 'power drill', + 'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck', + 'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', + 'radiator', 'radio', 'radio telescope', 'rain barrel', + 'recreational vehicle', 'fishing casting reel', 'reflex camera', + 'refrigerator', 'remote control', 'restaurant', 'revolver', 'rifle', + 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', + 'ruler measuring stick', 'sneaker', 'safe', 'safety pin', 'salt shaker', + 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale', + 'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw', + 'screwdriver', 'seat belt', 'sewing machine', 'shield', 'shoe store', + 'shoji screen / room divider', 'shopping basket', 'shopping cart', + 'shovel', 'shower cap', 'shower curtain', 'ski', 'balaclava ski mask', + 'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', + 'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock', + 'solar thermal collector', 'sombrero', 'soup bowl', 'keyboard space bar', + 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', + 'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive', + 'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall', + 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', + 'submarine', 'suit', 'sundial', 'sunglasses', 'sunglasses', 'sunscreen', + 'suspension bridge', 'mop', 'sweatshirt', 'swim trunks / shorts', 'swing', + 'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player', + 'teapot', 'teddy bear', 'television', 'tennis ball', 'thatched roof', + 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', + 'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole', + 'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray', + 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', + 'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard', + 'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase', + 'vaulted or arched ceiling', 'velvet fabric', 'vending machine', + 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', + 'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine', + 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', + 'hair wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', + 'airplane wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence', + 'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword', + 'traffic or street sign', 'traffic light', 'dust jacket', 'menu', 'plate', + 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream', 'popsicle', + 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', + 'mashed potatoes', 'cabbage', 'broccoli', 'cauliflower', 'zucchini', + 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', + 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple', + 'strawberry', 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit', + 'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara', + 'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie', 'burrito', + 'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble', 'cliff', + 'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach', + 'valley', 'volcano', 
'baseball player', 'bridegroom', 'scuba diver', + 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip', + 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', + 'stinkhorn mushroom', 'earth star fungus', 'hen of the woods mushroom', + 'bolete', 'corn cob', 'toilet paper') diff --git a/mmpretrain/datasets/cifar.py b/mmpretrain/datasets/cifar.py new file mode 100644 index 0000000..2a011da --- /dev/null +++ b/mmpretrain/datasets/cifar.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pickle +from typing import List, Optional + +import mmengine.dist as dist +import numpy as np +from mmengine.fileio import (LocalBackend, exists, get, get_file_backend, + join_path) +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES +from .utils import check_md5, download_and_extract_archive + + +@DATASETS.register_module() +class CIFAR10(BaseDataset): + """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. + + This implementation is modified from + https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py + + Args: + data_root (str): The root directory of the CIFAR Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Defaults to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if it does not exist. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ # noqa: E501 + + base_folder = 'cifar-10-batches-py' + url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + filename = 'cifar-10-python.tar.gz' + tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ] + + test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'], + ] + meta = { + 'filename': 'batches.meta', + 'key': 'label_names', + 'md5': '5ff9c542aee3614f3951f8cda6e48888', + } + METAINFO = {'classes': CIFAR10_CATEGORIES} + + def __init__(self, + data_root: str = '', + split: str = 'train', + metainfo: Optional[dict] = None, + download: bool = True, + data_prefix: str = '', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but got '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. '
' + 'The training set will be used.') + + if not data_root and not data_prefix: + raise RuntimeError('Please set ``data_root`` to' + 'specify the dataset path') + + self.download = download + super().__init__( + # The CIFAR dataset doesn't need specify annotation file + ann_file='', + metainfo=metainfo, + data_root=data_root, + data_prefix=dict(root=data_prefix), + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + root = self.data_prefix['root'] + backend = get_file_backend(root, enable_singleton=True) + + if dist.is_main_process() and not self._check_integrity(): + if not isinstance(backend, LocalBackend): + raise RuntimeError(f'The dataset on {root} is not integrated, ' + f'please manually handle it.') + + if self.download: + download_and_extract_archive( + self.url, root, filename=self.filename, md5=self.tgz_md5) + else: + raise RuntimeError( + f'Cannot find {self.__class__.__name__} dataset in ' + f"{self.data_prefix['root']}, you can specify " + '`download=True` to download automatically.') + + dist.barrier() + assert self._check_integrity(), \ + 'Download failed or shared storage is unavailable. Please ' \ + f'download the dataset manually through {self.url}.' + + if self.split == 'train': + downloaded_list = self.train_list + else: + downloaded_list = self.test_list + + imgs = [] + gt_labels = [] + + # load the picked numpy arrays + for file_name, _ in downloaded_list: + file_path = join_path(root, self.base_folder, file_name) + entry = pickle.loads(get(file_path), encoding='latin1') + imgs.append(entry['data']) + if 'labels' in entry: + gt_labels.extend(entry['labels']) + else: + gt_labels.extend(entry['fine_labels']) + + imgs = np.vstack(imgs).reshape(-1, 3, 32, 32) + imgs = imgs.transpose((0, 2, 3, 1)) # convert to HWC + + if self.CLASSES is None: + # The metainfo in the file has the lowest priority, therefore + # we only need to load it if classes is not specified. + self._load_meta() + + data_list = [] + for img, gt_label in zip(imgs, gt_labels): + info = {'img': img, 'gt_label': int(gt_label)} + data_list.append(info) + return data_list + + def _load_meta(self): + """Load categories information from metafile.""" + root = self.data_prefix['root'] + + path = join_path(root, self.base_folder, self.meta['filename']) + md5 = self.meta.get('md5', None) + if not exists(path) or (md5 is not None and not check_md5(path, md5)): + raise RuntimeError( + 'Dataset metadata file not found or corrupted.' + + ' You can use `download=True` to download it') + data = pickle.loads(get(path), encoding='latin1') + self._metainfo.setdefault('classes', data[self.meta['key']]) + + def _check_integrity(self): + """Check the integrity of data files.""" + root = self.data_prefix['root'] + + for fentry in (self.train_list + self.test_list): + filename, md5 = fentry[0], fentry[1] + fpath = join_path(root, self.base_folder, filename) + if not exists(fpath): + return False + if md5 is not None and not check_md5(fpath, md5): + return False + return True + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [f"Prefix of data: \t{self.data_prefix['root']}"] + return body + + +@DATASETS.register_module() +class CIFAR100(CIFAR10): + """`CIFAR100 `_ Dataset. + + Args: + data_root (str): The root directory of the CIFAR Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. 
Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + base_folder = 'cifar-100-python' + url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' + filename = 'cifar-100-python.tar.gz' + tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' + train_list = [ + ['train', '16019d7e3df5f24257cddd939b257f8d'], + ] + + test_list = [ + ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], + ] + meta = { + 'filename': 'meta', + 'key': 'fine_label_names', + 'md5': '7973b15100ade9c7d40fb424638fde48', + } + METAINFO = {'classes': CIFAR100_CATEGORIES} diff --git a/mmpretrain/datasets/coco_caption.py b/mmpretrain/datasets/coco_caption.py new file mode 100644 index 0000000..541cda8 --- /dev/null +++ b/mmpretrain/datasets/coco_caption.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class COCOCaption(BaseDataset): + """COCO Caption dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``.. + ann_file (str): Annotation file path. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + file_backend = get_file_backend(img_prefix) + + data_list = [] + for ann in annotations: + data_info = { + 'image_id': Path(ann['image']).stem.split('_')[-1], + 'img_path': file_backend.join_path(img_prefix, ann['image']), + 'gt_caption': ann['caption'], + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/coco_retrieval.py b/mmpretrain/datasets/coco_retrieval.py new file mode 100644 index 0000000..be8a0bc --- /dev/null +++ b/mmpretrain/datasets/coco_retrieval.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from collections import OrderedDict +from os import PathLike +from typing import List, Sequence, Union + +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS, TRANSFORMS +from .base_dataset import BaseDataset + + +def expanduser(data_prefix): + if isinstance(data_prefix, (str, PathLike)): + return osp.expanduser(data_prefix) + else: + return data_prefix + + +@DATASETS.register_module() +class COCORetrieval(BaseDataset): + """COCO Retrieval dataset. + + COCO (Common Objects in Context): The COCO dataset contains more than + 330K images,each of which has approximately 5 descriptive annotations. + This dataset was releasedin collaboration between Microsoft and Carnegie + Mellon University + + COCO_2014 dataset directory: :: + + COCO_2014 + ├── val2014 + ├── train2014 + ├── annotations + ├── instances_train2014.json + ├── instances_val2014.json + ├── person_keypoints_train2014.json + ├── person_keypoints_val2014.json + ├── captions_train2014.json + ├── captions_val2014.json + + Args: + ann_file (str): Annotation file path. + test_mode (bool): Whether dataset is used for evaluation. This will + decide the annotation format in data list annotations. + Defaults to False. 
+ data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import COCORetrieval + >>> train_dataset=COCORetrieval(data_root='coco2014/') + >>> train_dataset + Dataset COCORetrieval + Number of samples: 414113 + Annotation file: /coco2014/annotations/captions_train2014.json + Prefix of images: /coco2014/ + >>> from mmpretrain.datasets import COCORetrieval + >>> val_dataset = COCORetrieval(data_root='coco2014/') + >>> val_dataset + Dataset COCORetrieval + Number of samples: 202654 + Annotation file: /coco2014/annotations/captions_val2014.json + Prefix of images: /coco2014/ + """ + + def __init__(self, + ann_file: str, + test_mode: bool = False, + data_prefix: Union[str, dict] = '', + data_root: str = '', + pipeline: Sequence = (), + **kwargs): + + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + + ann_file = expanduser(ann_file) + transforms = [] + for transform in pipeline: + if isinstance(transform, dict): + transforms.append(TRANSFORMS.build(transform)) + else: + transforms.append(transform) + + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + pipeline=transforms, + ann_file=ann_file, + **kwargs, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + # get file backend + img_prefix = self.data_prefix['img_path'] + file_backend = get_file_backend(img_prefix) + + anno_info = json.load(open(self.ann_file, 'r')) + # mapping img_id to img filename + img_dict = OrderedDict() + for idx, img in enumerate(anno_info['images']): + if img['id'] not in img_dict: + img_rel_path = img['coco_url'].rsplit('/', 2)[-2:] + img_path = file_backend.join_path(img_prefix, *img_rel_path) + + # create new idx for image + img_dict[img['id']] = dict( + ori_id=img['id'], + image_id=idx, # will be used for evaluation + img_path=img_path, + text=[], + gt_text_id=[], + gt_image_id=[], + ) + + train_list = [] + for idx, anno in enumerate(anno_info['annotations']): + anno['text'] = anno.pop('caption') + anno['ori_id'] = anno.pop('id') + anno['text_id'] = idx # will be used for evaluation + # 1. prepare train data list item + train_data = anno.copy() + train_image = img_dict[train_data['image_id']] + train_data['img_path'] = train_image['img_path'] + train_data['image_ori_id'] = train_image['ori_id'] + train_data['image_id'] = train_image['image_id'] + train_data['is_matched'] = True + train_list.append(train_data) + # 2. prepare eval data list item based on img dict + img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id']) + img_dict[anno['image_id']]['text'].append(anno['text']) + img_dict[anno['image_id']]['gt_image_id'].append( + train_image['image_id']) + + self.img_size = len(img_dict) + self.text_size = len(anno_info['annotations']) + + # return needed format data list + if self.test_mode: + return list(img_dict.values()) + return train_list diff --git a/mmpretrain/datasets/coco_vqa.py b/mmpretrain/datasets/coco_vqa.py new file mode 100644 index 0000000..85f4bdc --- /dev/null +++ b/mmpretrain/datasets/coco_vqa.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
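+# A minimal usage sketch for the ``COCOVQA`` dataset defined below. The
+# directory and file names are illustrative assumptions, not paths required
+# by the class:
+#
+#     from mmpretrain.datasets import COCOVQA
+#     dataset = COCOVQA(
+#         data_root='data/coco',
+#         data_prefix='images/val2014',
+#         question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
+#         ann_file='annotations/v2_mscoco_val2014_annotations.json')
+#     item = dataset[0]  # contains 'question', 'img_path' and, when
+#                        # ``ann_file`` is given, 'gt_answer'/'gt_answer_weight'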
+import os.path as osp +import re +from collections import Counter +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class COCOVQA(BaseDataset): + """VQAv2 dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + question_file (str): Question file path. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + question_file: str, + ann_file: str = '', + **kwarg): + self.question_file = question_file + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def _join_prefix(self): + if not mmengine.is_abs(self.question_file) and self.question_file: + self.question_file = osp.join(self.data_root, self.question_file) + + return super()._join_prefix() + + def _create_image_index(self): + img_prefix = self.data_prefix['img_path'] + + files = mmengine.list_dir_or_file(img_prefix, list_dir=False) + image_index = {} + for file in files: + image_id = re.findall(r'\d{12}', file) + if len(image_id) > 0: + image_id = int(image_id[-1]) + image_index[image_id] = mmengine.join_path(img_prefix, file) + + return image_index + + def load_data_list(self) -> List[dict]: + """Load data list.""" + questions = mmengine.load(self.question_file)['questions'] + if self.ann_file: + annotations = mmengine.load(self.ann_file)['annotations'] + assert len(questions) == len(annotations) + else: + annotations = [None] * len(questions) + + # The original VQAv2 annotation file and question file includes + # only image id but no image file paths. + self.image_index = self._create_image_index() + + data_list = [] + for question, ann in zip(questions, annotations): + # question example + # { + # 'image_id': 262144, + # 'question': "Is the ball flying towards the batter?", + # 'question_id': 262144000 + # } + # + # ann example + # { + # 'question_type': "what are the", + # 'answer_type': "other", + # 'answers': [ + # {'answer': 'watching', + # 'answer_id': 1, + # 'answer_confidence': 'yes'}, + # ... + # ], + # 'image_id': 262148, + # 'question_id': 262148000, + # 'multiple_choice_answer': 'watching', + # 'answer_type': 'other', + # } + + data_info = question + data_info['img_path'] = self.image_index[question['image_id']] + + if ann is not None: + assert ann['question_id'] == question['question_id'] + + # add answer_weight & answer_count, delete duplicate answer + answers = [item['answer'] for item in ann.pop('answers')] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + data_info['gt_answer'] = list(count.keys()) + data_info['gt_answer_weight'] = answer_weight + data_info.update(ann) + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/cub.py b/mmpretrain/datasets/cub.py new file mode 100644 index 0000000..8db1262 --- /dev/null +++ b/mmpretrain/datasets/cub.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List + +from mmengine import get_file_backend, list_from_file +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import CUB_CATEGORIES + + +@DATASETS.register_module() +class CUB(BaseDataset): + """The CUB-200-2011 Dataset. + + Support the `CUB-200-2011 `_ Dataset. + Comparing with the `CUB-200 `_ Dataset, + there are much more pictures in `CUB-200-2011`. After downloading and decompression, the dataset + directory structure is as follows. + + CUB dataset directory: :: + + CUB_200_2011 + ├── images + │ ├── class_x + │ │ ├── xx1.jpg + │ │ ├── xx2.jpg + │ │ └── ... + │ ├── class_y + │ │ ├── yy1.jpg + │ │ ├── yy2.jpg + │ │ └── ... + │ └── ... + ├── images.txt + ├── image_class_labels.txt + ├── train_test_split.txt + └── .... + + Args: + data_root (str): The root directory for CUB-200-2011 dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + + Examples: + >>> from mmpretrain.datasets import CUB + >>> train_dataset = CUB(data_root='data/CUB_200_2011', split='train') + >>> train_dataset + Dataset CUB + Number of samples: 5994 + Number of categories: 200 + Root of dataset: data/CUB_200_2011 + >>> test_dataset = CUB(data_root='data/CUB_200_2011', split='test') + >>> test_dataset + Dataset CUB + Number of samples: 5794 + Number of categories: 200 + Root of dataset: data/CUB_200_2011 + """ # noqa: E501 + + METAINFO = {'classes': CUB_CATEGORIES} + + def __init__(self, + data_root: str, + split: str = 'train', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. ' + 'The training set will be used.') + + ann_file = 'images.txt' + data_prefix = 'images' + image_class_labels_file = 'image_class_labels.txt' + train_test_split_file = 'train_test_split.txt' + + self.backend = get_file_backend(data_root, enable_singleton=True) + self.image_class_labels_file = self.backend.join_path( + data_root, image_class_labels_file) + self.train_test_split_file = self.backend.join_path( + data_root, train_test_split_file) + super(CUB, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def _load_data_from_txt(self, filepath): + """load data from CUB txt file, the every line of the file is idx and a + data item.""" + pairs = list_from_file(filepath) + data_dict = dict() + for pair in pairs: + idx, data_item = pair.split() + # all the index starts from 1 in CUB files, + # here we need to '- 1' to let them start from 0. 
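+            # e.g. an illustrative pair "5 Some_Class/Some_Class_0001.jpg"
+            # is stored as data_dict[4] = 'Some_Class/Some_Class_0001.jpg'.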
+ data_dict[int(idx) - 1] = data_item + return data_dict + + def load_data_list(self): + """Load images and ground truth labels.""" + sample_dict = self._load_data_from_txt(self.ann_file) + + label_dict = self._load_data_from_txt(self.image_class_labels_file) + + split_dict = self._load_data_from_txt(self.train_test_split_file) + + assert sample_dict.keys() == label_dict.keys() == split_dict.keys(),\ + f'sample_ids should be same in files {self.ann_file}, ' \ + f'{self.image_class_labels_file} and {self.train_test_split_file}' + + data_list = [] + for sample_id in sample_dict.keys(): + if split_dict[sample_id] == '1' and self.split == 'test': + # skip train samples when split='test' + continue + elif split_dict[sample_id] == '0' and self.split == 'train': + # skip test samples when split='train' + continue + + img_path = self.backend.join_path(self.img_prefix, + sample_dict[sample_id]) + gt_label = int(label_dict[sample_id]) - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/custom.py b/mmpretrain/datasets/custom.py new file mode 100644 index 0000000..bb491ff --- /dev/null +++ b/mmpretrain/datasets/custom.py @@ -0,0 +1,287 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +from mmengine.fileio import (BaseStorageBackend, get_file_backend, + list_from_file) +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +def find_folders( + root: str, + backend: Optional[BaseStorageBackend] = None +) -> Tuple[List[str], Dict[str, int]]: + """Find classes by folders under a root. + + Args: + root (string): root directory of folders + backend (BaseStorageBackend | None): The file backend of the root. + If None, auto infer backend from the root path. Defaults to None. + + Returns: + Tuple[List[str], Dict[str, int]]: + + - folders: The name of sub folders under the root. + - folder_to_idx: The map from folder name to class idx. + """ + # Pre-build file backend to prevent verbose file backend inference. + backend = backend or get_file_backend(root, enable_singleton=True) + folders = list( + backend.list_dir_or_file( + root, + list_dir=True, + list_file=False, + recursive=False, + )) + folders.sort() + folder_to_idx = {folders[i]: i for i in range(len(folders))} + return folders, folder_to_idx + + +def get_samples( + root: str, + folder_to_idx: Dict[str, int], + is_valid_file: Callable, + backend: Optional[BaseStorageBackend] = None, +): + """Make dataset by walking all images under a root. + + Args: + root (string): root directory of folders + folder_to_idx (dict): the map from class name to class idx + is_valid_file (Callable): A function that takes path of a file + and check if the file is a valid sample file. + backend (BaseStorageBackend | None): The file backend of the root. + If None, auto infer backend from the root path. Defaults to None. + + Returns: + Tuple[list, set]: + + - samples: a list of tuple where each element is (image, class_idx) + - empty_folders: The folders don't have any valid files. + """ + samples = [] + available_classes = set() + # Pre-build file backend to prevent verbose file backend inference. 
+ backend = backend or get_file_backend(root, enable_singleton=True) + + if folder_to_idx is not None: + for folder_name in sorted(list(folder_to_idx.keys())): + _dir = backend.join_path(root, folder_name) + files = backend.list_dir_or_file( + _dir, + list_dir=False, + list_file=True, + recursive=True, + ) + for file in sorted(list(files)): + if is_valid_file(file): + path = backend.join_path(folder_name, file) + item = (path, folder_to_idx[folder_name]) + samples.append(item) + available_classes.add(folder_name) + empty_folders = set(folder_to_idx.keys()) - available_classes + else: + files = backend.list_dir_or_file( + root, + list_dir=False, + list_file=True, + recursive=True, + ) + samples = [file for file in sorted(list(files)) if is_valid_file(file)] + empty_folders = None + + return samples, empty_folders + + +@DATASETS.register_module() +class CustomDataset(BaseDataset): + """A generic dataset for multiple tasks. + + The dataset supports two kinds of style. + + 1. Use an annotation file to specify all samples, and each line indicates a + sample: + + The annotation file (for ``with_label=True``, supervised tasks.): :: + + folder_1/xxx.png 0 + folder_1/xxy.png 1 + 123.png 4 + nsdf3.png 3 + ... + + The annotation file (for ``with_label=False``, unsupervised tasks.): :: + + folder_1/xxx.png + folder_1/xxy.png + 123.png + nsdf3.png + ... + + Sample files: :: + + data_prefix/ + ├── folder_1 + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + ├── 123.png + ├── nsdf3.png + └── ... + + Please use the argument ``metainfo`` to specify extra information for + the task, like ``{'classes': ('bird', 'cat', 'deer', 'dog', 'frog')}``. + + 2. Place all samples in one folder as below: + + Sample files (for ``with_label=True``, supervised tasks, we use the name + of sub-folders as the categories names): :: + + data_prefix/ + ├── class_x + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + │ └── xxz.png + └── class_y + ├── 123.png + ├── nsdf3.png + ├── ... + └── asd932_.png + + Sample files (for ``with_label=False``, unsupervised tasks, we use all + sample files under the specified folder): :: + + data_prefix/ + ├── folder_1 + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + ├── 123.png + ├── nsdf3.png + └── ... + + If the ``ann_file`` is specified, the dataset will be generated by the + first way, otherwise, try the second way. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for the data. Defaults to ''. + ann_file (str): Annotation file path. Defaults to ''. + with_label (bool): Whether the annotation file includes ground truth + labels, or use sub-folders to specify categories. + Defaults to True. + extensions (Sequence[str]): A sequence of allowed extensions. Defaults + to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'). + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + lazy_init (bool): Whether to load annotation during instantiation. + In some cases, such as visualization, only the meta information of + the dataset is needed, which is not necessary to load annotation + file. ``Basedataset`` can skip load annotations to save time by set + ``lazy_init=False``. Defaults to False. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
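+
+    Examples:
+        Assuming the sub-folder style described above, with an illustrative
+        path ``data/my_dataset/train`` holding one folder per category::
+
+            >>> from mmpretrain.datasets import CustomDataset
+            >>> dataset = CustomDataset(data_prefix='data/my_dataset/train')
+            >>> dataset.CLASSES  # category names inferred from folder names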
+ """ + + def __init__(self, + data_root: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + with_label=True, + extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', + '.bmp', '.pgm', '.tif'), + metainfo: Optional[dict] = None, + lazy_init: bool = False, + **kwargs): + assert (ann_file or data_prefix or data_root), \ + 'One of `ann_file`, `data_root` and `data_prefix` must '\ + 'be specified.' + + self.extensions = tuple(set([i.lower() for i in extensions])) + self.with_label = with_label + + super().__init__( + # The base class requires string ann_file but this class doesn't + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + # Force to lazy_init for some modification before loading data. + lazy_init=True, + **kwargs) + + # Full initialize the dataset. + if not lazy_init: + self.full_init() + + def _find_samples(self): + """find samples from ``data_prefix``.""" + if self.with_label: + classes, folder_to_idx = find_folders(self.img_prefix) + samples, empty_classes = get_samples( + self.img_prefix, + folder_to_idx, + is_valid_file=self.is_valid_file, + ) + + self.folder_to_idx = folder_to_idx + + if self.CLASSES is not None: + assert len(self.CLASSES) == len(classes), \ + f"The number of subfolders ({len(classes)}) doesn't " \ + f'match the number of specified classes ' \ + f'({len(self.CLASSES)}). Please check the data folder.' + else: + self._metainfo['classes'] = tuple(classes) + else: + samples, empty_classes = get_samples( + self.img_prefix, + None, + is_valid_file=self.is_valid_file, + ) + + if len(samples) == 0: + raise RuntimeError( + f'Found 0 files in subfolders of: {self.data_prefix}. ' + f'Supported extensions are: {",".join(self.extensions)}') + + if empty_classes: + logger = MMLogger.get_current_instance() + logger.warning( + 'Found no valid file in the folder ' + f'{", ".join(empty_classes)}. ' + f"Supported extensions are: {', '.join(self.extensions)}") + + return samples + + def load_data_list(self): + """Load image paths and gt_labels.""" + if not self.ann_file: + samples = self._find_samples() + elif self.with_label: + lines = list_from_file(self.ann_file) + samples = [x.strip().rsplit(' ', 1) for x in lines] + else: + samples = list_from_file(self.ann_file) + + # Pre-build file backend to prevent verbose file backend inference. + backend = get_file_backend(self.img_prefix, enable_singleton=True) + data_list = [] + for sample in samples: + if self.with_label: + filename, gt_label = sample + img_path = backend.join_path(self.img_prefix, filename) + info = {'img_path': img_path, 'gt_label': int(gt_label)} + else: + img_path = backend.join_path(self.img_prefix, sample) + info = {'img_path': img_path} + data_list.append(info) + return data_list + + def is_valid_file(self, filename: str) -> bool: + """Check if a file is a valid sample.""" + return filename.lower().endswith(self.extensions) diff --git a/mmpretrain/datasets/dataset_wrappers.py b/mmpretrain/datasets/dataset_wrappers.py new file mode 100644 index 0000000..1adff10 --- /dev/null +++ b/mmpretrain/datasets/dataset_wrappers.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import numpy as np +from mmengine.dataset import BaseDataset, force_full_init + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class KFoldDataset: + """A wrapper of dataset for K-Fold cross-validation. + + K-Fold cross-validation divides all the samples in groups of samples, + called folds, of almost equal sizes. 
And we use k-1 of folds to do training + and use the fold left to do validation. + + Args: + dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be + divided + fold (int): The fold used to do validation. Defaults to 0. + num_splits (int): The number of all folds. Defaults to 5. + test_mode (bool): Use the training dataset or validation dataset. + Defaults to False. + seed (int, optional): The seed to shuffle the dataset before splitting. + If None, not shuffle the dataset. Defaults to None. + """ + + def __init__(self, + dataset, + fold=0, + num_splits=5, + test_mode=False, + seed=None): + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + # Init the dataset wrapper lazily according to the dataset setting. + lazy_init = dataset.get('lazy_init', False) + elif isinstance(dataset, BaseDataset): + self.dataset = dataset + else: + raise TypeError(f'Unsupported dataset type {type(dataset)}.') + + self._metainfo = getattr(self.dataset, 'metainfo', {}) + self.fold = fold + self.num_splits = num_splits + self.test_mode = test_mode + self.seed = seed + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def metainfo(self) -> dict: + """Get the meta information of ``self.dataset``. + + Returns: + dict: Meta information of the dataset. + """ + # Prevent `self._metainfo` from being modified by outside. + return copy.deepcopy(self._metainfo) + + def full_init(self): + """fully initialize the dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + ori_len = len(self.dataset) + indices = list(range(ori_len)) + if self.seed is not None: + rng = np.random.default_rng(self.seed) + rng.shuffle(indices) + + test_start = ori_len * self.fold // self.num_splits + test_end = ori_len * (self.fold + 1) // self.num_splits + if self.test_mode: + indices = indices[test_start:test_end] + else: + indices = indices[:test_start] + indices[test_end:] + + self._ori_indices = indices + self.dataset = self.dataset.get_subset(indices) + + self._fully_initialized = True + + @force_full_init + def _get_ori_dataset_idx(self, idx: int) -> int: + """Convert global idx to local index. + + Args: + idx (int): Global index of ``KFoldDataset``. + + Returns: + int: The original index in the whole dataset. + """ + return self._ori_indices[idx] + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``KFoldDataset``. + + Returns: + dict: The idx-th annotation of the datasets. + """ + return self.dataset.get_data_info(idx) + + @force_full_init + def __len__(self): + return len(self.dataset) + + @force_full_init + def __getitem__(self, idx): + return self.dataset[idx] + + @force_full_init + def get_cat_ids(self, idx): + return self.dataset.get_cat_ids(idx) + + @force_full_init + def get_gt_labels(self): + return self.dataset.get_gt_labels() + + @property + def CLASSES(self): + """Return all categories names.""" + return self._metainfo.get('classes', None) + + @property + def class_to_idx(self): + """Map mapping class name to class index. + + Returns: + dict: mapping from class name to class index. + """ + + return {cat: i for i, cat in enumerate(self.CLASSES)} + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. 
+ """ + head = 'Dataset ' + self.__class__.__name__ + body = [] + type_ = 'test' if self.test_mode else 'training' + body.append(f'Type: \t{type_}') + body.append(f'Seed: \t{self.seed}') + + def ordinal(n): + # Copy from https://codegolf.stackexchange.com/a/74047 + suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4] + return f'{n}{suffix}' + + body.append( + f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold') + if self._fully_initialized: + body.append(f'Number of samples: \t{self.__len__()}') + else: + body.append("Haven't been initialized") + + if self.CLASSES is not None: + body.append(f'Number of categories: \t{len(self.CLASSES)}') + else: + body.append('The `CLASSES` meta info is not set.') + + body.append( + f'Original dataset type:\t{self.dataset.__class__.__name__}') + + lines = [head] + [' ' * 4 + line for line in body] + return '\n'.join(lines) diff --git a/mmpretrain/datasets/dtd.py b/mmpretrain/datasets/dtd.py new file mode 100644 index 0000000..034d0b1 --- /dev/null +++ b/mmpretrain/datasets/dtd.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import DTD_CATEGORIES + + +@DATASETS.register_module() +class DTD(BaseDataset): + """The Describable Texture Dataset (DTD). + + Support the `Describable Texture Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + DTD dataset directory: :: + + dtd + ├── images + │ ├── banded + | | ├──banded_0002.jpg + | | ├──banded_0004.jpg + | | └── ... + │ └── ... + ├── imdb + │ └── imdb.mat + ├── labels + | | ├──labels_joint_anno.txt + | | ├──test1.txt + | | ├──test2.txt + | | └── ... + │ └── ... + └── .... + + Args: + data_root (str): The root directory for Describable Texture dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". 
+ + Examples: + >>> from mmpretrain.datasets import DTD + >>> train_dataset = DTD(data_root='data/dtd', split='trainval') + >>> train_dataset + Dataset DTD + Number of samples: 3760 + Number of categories: 47 + Root of dataset: data/dtd + >>> test_dataset = DTD(data_root='data/dtd', split='test') + >>> test_dataset + Dataset DTD + Number of samples: 1880 + Number of categories: 47 + Root of dataset: data/dtd + """ # noqa: E501 + + METAINFO = {'classes': DTD_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + data_prefix = 'images' + test_mode = split == 'test' + + self.backend = get_file_backend(data_root, enable_singleton=True) + ann_file = self.backend.join_path('imdb', 'imdb.mat') + + super(DTD, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + data = mat4py.loadmat(self.ann_file)['images'] + names = data['name'] + labels = data['class'] + parts = data['set'] + num = len(names) + assert num == len(labels) == len(parts), 'get error ann file' + + if self.split == 'train': + target_set = {1} + elif self.split == 'val': + target_set = {2} + elif self.split == 'test': + target_set = {3} + else: + target_set = {1, 2} + + data_list = [] + for i in range(num): + if parts[i] in target_set: + img_name = names[i] + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = labels[i] - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/fgvcaircraft.py b/mmpretrain/datasets/fgvcaircraft.py new file mode 100644 index 0000000..696992c --- /dev/null +++ b/mmpretrain/datasets/fgvcaircraft.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FGVCAIRCRAFT_CATEGORIES + + +@DATASETS.register_module() +class FGVCAircraft(BaseDataset): + """The FGVC_Aircraft Dataset. + + Support the `FGVC_Aircraft Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + FGVC_Aircraft dataset directory: :: + + fgvc-aircraft-2013b + └── data + ├── images + │ ├── 1.jpg + │ ├── 2.jpg + │ └── ... + ├── images_variant_train.txt + ├── images_variant_test.txt + ├── images_variant_trainval.txt + ├── images_variant_val.txt + ├── variants.txt + └── .... + + Args: + data_root (str): The root directory for FGVC_Aircraft dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". 
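+        **kwargs: Other keyword arguments in :class:`BaseDataset`.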
+ + Examples: + >>> from mmpretrain.datasets import FGVCAircraft + >>> train_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='trainval') + >>> train_dataset + Dataset FGVCAircraft + Number of samples: 6667 + Number of categories: 100 + Root of dataset: data/fgvc-aircraft-2013b + >>> test_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='test') + >>> test_dataset + Dataset FGVCAircraft + Number of samples: 3333 + Number of categories: 100 + Root of dataset: data/fgvc-aircraft-2013b + """ # noqa: E501 + + METAINFO = {'classes': FGVCAIRCRAFT_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + ann_file = self.backend.join_path('data', + f'images_variant_{split}.txt') + data_prefix = self.backend.join_path('data', 'images') + test_mode = split == 'test' + + super(FGVCAircraft, self).__init__( + ann_file=ann_file, + data_root=data_root, + test_mode=test_mode, + data_prefix=data_prefix, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + for pair in pairs: + pair = pair.split() + img_name = pair[0] + class_name = ' '.join(pair[1:]) + img_name = f'{img_name}.jpg' + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = self.METAINFO['classes'].index(class_name) + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/flamingo.py b/mmpretrain/datasets/flamingo.py new file mode 100644 index 0000000..3b5745a --- /dev/null +++ b/mmpretrain/datasets/flamingo.py @@ -0,0 +1,295 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from abc import abstractmethod +from collections import Counter +from typing import List + +import mmengine +import numpy as np +from mmengine.dataset import BaseDataset +from pycocotools.coco import COCO + +from mmpretrain.registry import DATASETS +from .coco_vqa import COCOVQA + + +class FlamingoFewShotMixin: + """Flamingo fewshot eval dataset minin. + + Args: + num_shots (int): Number of shots to perform evaluation. + Defaults to 0. + Note: 0 does not mean a strict zero-shot in Flamingo setting. + It will use 2 only-text prompt without in context images. + num_support_examples (int): Number of support examples to get the + few shots from. Defaults to 2048. + num_query_examples (int): Number of query examples to perform the + final evaluation. Defaults to 5000. + incontext_prompt_temp (str): In context prompt template for few shot + examples. Defaults to ''. + final_prompt_temp (str): Final query prompt template. Defaults to ''. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
+ """ + + def __init__(self, + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + incontext_prompt_temp: str = '', + final_prompt_temp: str = '', + **kwarg): + self.num_shots = num_shots + self.num_support_examples = num_support_examples + self.num_query_examples = num_query_examples + self.incontext_prompt_temp = incontext_prompt_temp + self.final_prompt_temp = final_prompt_temp + super().__init__(**kwarg) + + def get_subset_idx(self, total_num): + random_idx = np.random.choice( + total_num, + self.num_support_examples + self.num_query_examples, + replace=False) + + support_idx = random_idx[:self.num_support_examples] + query_idx = random_idx[self.num_support_examples:] + return support_idx, query_idx + + @abstractmethod + def parse_basic_anno(self, anno: dict) -> dict: + """Parse basic annotation for support and query set.""" + pass + + @abstractmethod + def parse_fewshot_anno(self, anno: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list.""" + pass + + +@DATASETS.register_module() +class FlamingoEvalCOCOVQA(FlamingoFewShotMixin, COCOVQA): + """Flamingo few shot VQAv2 dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. + ann_file (str): Annotation file path. + question_file (str): Question file path. + num_shots (int): Number of shots to perform evaluation. + Defaults to 0. + Note: 0 does not mean a strict zero-shot in Flamingo setting. + It will use 2 only-text prompt without in context images. + num_support_examples (int): Number of support examples to get the + few shots from. Defaults to 2048. + num_query_examples (int): Number of query examples to perform the + final evaluation. Defaults to 5000. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + question_file: str, + ann_file: str = '', + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + **kwarg): + super().__init__( + data_root=data_root, + question_file=question_file, + ann_file=ann_file, + num_shots=num_shots, + num_support_examples=num_support_examples, + num_query_examples=num_query_examples, + **kwarg) + + def parse_basic_anno(self, ann: dict) -> dict: + """Parse basic annotation for support and query set. + + Args: + anno (dict): Annotation for single example. + + Return: + dict: Parsed annotation for single example. + """ + if ann is None: + return {} + + answers = [a['answer'] for a in ann['answers']] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + answer_info = { + 'gt_answer': list(count.keys()), + 'gt_answer_weight': answer_weight + } + return answer_info + + def parse_fewshot_anno(self, query: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list. + + Args: + anno (dict): Annotation for single example. + support_list (List): List of support subset to subsample few shots. + + Return: + dict: Parsed annotation for single example. 
+ """ + # prepare n shots examples + shots = random.sample(support_list, self.num_shots) + + # append image path for n shots + img_path = [shot['img_path'] for shot in shots] + img_path.append(query['img_path']) + query['img_path'] = img_path + + query['shots'] = [ + dict( + question=item['question'], + answer=item['gt_answer'][0], + ) for item in shots + ] + return query + + def load_data_list(self) -> List[dict]: + """Load data list.""" + questions = mmengine.load(self.question_file)['questions'] + if self.ann_file: + annotations = mmengine.load(self.ann_file)['annotations'] + assert len(questions) == len(annotations) + else: + annotations = [None] * len(questions) + if self.num_shots > 0: + raise ValueError('Unable to construct few-shot examples ' + 'since no annotation file.') + + # The original VQAv2 annotation file and question file includes + # only image id but no image file paths. + self.image_index = self._create_image_index() + + num_data = len(questions) + support_idx, query_idx = self.get_subset_idx(num_data) + + # prepare support subset + if self.num_shots > 0: + support_list = [] + for idx in support_idx: + question = questions[idx] + ann = annotations[idx] + support = {**question, **self.parse_basic_anno(ann)} + support['img_path'] = self.image_index[question['image_id']] + support_list.append(support) + + # prepare query subset + data_list = [] + for idx in query_idx: + question = questions[idx] + ann = annotations[idx] + data_info = {**question, **self.parse_basic_anno(ann)} + data_info['img_path'] = self.image_index[question['image_id']] + if self.num_shots > 0: + data_info = self.parse_fewshot_anno(data_info, support_list) + data_list.append(data_info) + + return data_list + + +@DATASETS.register_module() +class FlamingoEvalCOCOCaption(FlamingoFewShotMixin, BaseDataset): + """Flamingo few shot COCO Caption dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. + ann_file (str): Annotation file path. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + num_shots (int): Number of shots to perform evaluation. + Defaults to 0. + num_support_examples (int): Number of support examples to get the + few shots from. Defaults to 2048. + num_query_examples (int): Number of query examples to perform the + final evaluation. Defaults to 5000. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + ann_file: str, + num_shots: int = 0, + num_support_examples: int = 2048, + num_query_examples: int = 5000, + **kwarg): + super().__init__( + data_root=data_root, + ann_file=ann_file, + num_shots=num_shots, + num_support_examples=num_support_examples, + num_query_examples=num_query_examples, + **kwarg) + + def parse_basic_anno(self, ann: dict, coco: COCO) -> dict: + """Parse basic annotation for support and query set. + + Args: + anno (dict): Annotation for single example. + coco (COCO): The coco dataset. + + Return: + dict: Parsed annotation for single example. + """ + img_prefix = self.data_prefix['img_path'] + img = coco.imgs[ann['image_id']] + data_info = dict( + img_path=mmengine.join_path(img_prefix, img['file_name']), + gt_caption=ann['caption'], + image_id=ann['image_id'], + ) + return data_info + + def parse_fewshot_anno(self, query: dict, support_list: List) -> dict: + """Parse fewshot related annotation for query set with support list. + + Args: + query (dict): Annotation for single example. 
+ support_list (List): List of support subset to subsample few shots. + coco (COCO): The coco dataset. + + Return: + dict: Parsed annotation for single example. + """ + # prepare n shots examples + shots = random.sample(support_list, self.num_shots) + + # append image path for n shots + img_path = [shot['img_path'] for shot in shots] + img_path.append(query['img_path']) + query['img_path'] = img_path + + query['shots'] = [dict(caption=item['gt_caption']) for item in shots] + return query + + def load_data_list(self) -> List[dict]: + """Load data list.""" + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + + num_data = len(coco.anns) + support_idx, query_idx = self.get_subset_idx(num_data) + ann_ids = list(coco.anns) + + # prepare support subset + if self.num_shots > 0: + support_list = [] + for idx in support_idx: + support = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco) + support_list.append(support) + + # prepare query subset + query_list = [] + for idx in query_idx: + data_info = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco) + if self.num_shots > 0: + data_info = self.parse_fewshot_anno(data_info, support_list) + query_list.append(data_info) + + return query_list diff --git a/mmpretrain/datasets/flickr30k_caption.py b/mmpretrain/datasets/flickr30k_caption.py new file mode 100644 index 0000000..f0f6841 --- /dev/null +++ b/mmpretrain/datasets/flickr30k_caption.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class Flickr30kCaption(BaseDataset): + """Flickr30k Caption dataset. To generate coco-style GT annotation for + evaluation, please refer to + tools/dataset_converters/convert_flickr30k_ann.py. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
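+
+    Examples:
+        Assuming a Karpathy-style annotation file and images stored under an
+        illustrative ``data/flickr30k`` directory::
+
+            >>> from mmpretrain.datasets import Flickr30kCaption
+            >>> dataset = Flickr30kCaption(
+            ...     data_root='data/flickr30k',
+            ...     data_prefix='images',
+            ...     ann_file='dataset_flickr30k.json',
+            ...     split='train')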
+ """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + file_backend = get_file_backend(img_prefix) + + data_list = [] + + for img in annotations['images']: + + # img_example={ + # "sentids": [0, 1, 2], + # "imgid": 0, + # "sentences": [ + # {"raw": "Two men in green shirts standing in a yard.", + # "imgid": 0, "sentid": 0}, + # {"raw": "A man in a blue shirt standing in a garden.", + # "imgid": 0, "sentid": 1}, + # {"raw": "Two friends enjoy time spent together.", + # "imgid": 0, "sentid": 2} + # ], + # "split": "train", + # "filename": "1000092795.jpg" + # }, + + if img['split'] != self.split: + continue + + for sentence in img['sentences']: + data_info = { + 'image_id': img['imgid'], + 'img_path': file_backend.join_path(img_prefix, + img['filename']), + 'gt_caption': sentence['raw'] + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/flickr30k_retrieval.py b/mmpretrain/datasets/flickr30k_retrieval.py new file mode 100644 index 0000000..9f43c15 --- /dev/null +++ b/mmpretrain/datasets/flickr30k_retrieval.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import List + +import mmengine +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class Flickr30kRetrieval(BaseDataset): + """Flickr30k Retrieval dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
+ """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + # get file backend + img_prefix = self.data_prefix['img_path'] + file_backend = get_file_backend(img_prefix) + + annotations = mmengine.load(self.ann_file) + + # mapping img_id to img filename + img_dict = OrderedDict() + img_idx = 0 + sentence_idx = 0 + train_list = [] + for img in annotations['images']: + + # img_example={ + # "sentids": [0, 1, 2], + # "imgid": 0, + # "sentences": [ + # {"raw": "Two men in green shirts standing in a yard.", + # "imgid": 0, "sentid": 0}, + # {"raw": "A man in a blue shirt standing in a garden.", + # "imgid": 0, "sentid": 1}, + # {"raw": "Two friends enjoy time spent together.", + # "imgid": 0, "sentid": 2} + # ], + # "split": "train", + # "filename": "1000092795.jpg" + # }, + + if img['split'] != self.split: + continue + + # create new idx for image + train_image = dict( + ori_id=img['imgid'], + image_id=img_idx, # used for evaluation + img_path=file_backend.join_path(img_prefix, img['filename']), + text=[], + gt_text_id=[], + gt_image_id=[], + ) + + for sentence in img['sentences']: + ann = {} + ann['text'] = sentence['raw'] + ann['ori_id'] = sentence['sentid'] + ann['text_id'] = sentence_idx # used for evaluation + + ann['image_ori_id'] = train_image['ori_id'] + ann['image_id'] = train_image['image_id'] + ann['img_path'] = train_image['img_path'] + ann['is_matched'] = True + + # 1. prepare train data list item + train_list.append(ann) + # 2. prepare eval data list item based on img dict + train_image['text'].append(ann['text']) + train_image['gt_text_id'].append(ann['text_id']) + train_image['gt_image_id'].append(ann['image_id']) + + sentence_idx += 1 + + img_dict[img['imgid']] = train_image + img_idx += 1 + + self.img_size = len(img_dict) + self.text_size = len(train_list) + + # return needed format data list + if self.test_mode: + return list(img_dict.values()) + return train_list diff --git a/mmpretrain/datasets/flowers102.py b/mmpretrain/datasets/flowers102.py new file mode 100644 index 0000000..fe76dcc --- /dev/null +++ b/mmpretrain/datasets/flowers102.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class Flowers102(BaseDataset): + """The Oxford 102 Flower Dataset. + + Support the `Oxford 102 Flowers Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Flowers102 dataset directory: :: + + Flowers102 + ├── jpg + │ ├── image_00001.jpg + │ ├── image_00002.jpg + │ └── ... + ├── imagelabels.mat + ├── setid.mat + └── ... + + Args: + data_root (str): The root directory for Oxford 102 Flowers dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". 
+ + Examples: + >>> from mmpretrain.datasets import Flowers102 + >>> train_dataset = Flowers102(data_root='data/Flowers102', split='trainval') + >>> train_dataset + Dataset Flowers102 + Number of samples: 2040 + Root of dataset: data/Flowers102 + >>> test_dataset = Flowers102(data_root='data/Flowers102', split='test') + >>> test_dataset + Dataset Flowers102 + Number of samples: 6149 + Root of dataset: data/Flowers102 + """ # noqa: E501 + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + ann_file = 'imagelabels.mat' + data_prefix = 'jpg' + train_test_split_file = 'setid.mat' + test_mode = split == 'test' + + self.backend = get_file_backend(data_root, enable_singleton=True) + + self.train_test_split_file = self.backend.join_path( + data_root, train_test_split_file) + + super(Flowers102, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + label_dict = mat4py.loadmat(self.ann_file)['labels'] + split_list = mat4py.loadmat(self.train_test_split_file) + + if self.split == 'train': + split_list = split_list['trnid'] + elif self.split == 'val': + split_list = split_list['valid'] + elif self.split == 'test': + split_list = split_list['tstid'] + else: + train_ids = split_list['trnid'] + val_ids = split_list['valid'] + train_ids.extend(val_ids) + split_list = train_ids + + data_list = [] + for sample_id in split_list: + img_name = 'image_%05d.jpg' % (sample_id) + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = int(label_dict[sample_id - 1]) - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/food101.py b/mmpretrain/datasets/food101.py new file mode 100644 index 0000000..4ce7ffe --- /dev/null +++ b/mmpretrain/datasets/food101.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FOOD101_CATEGORIES + + +@DATASETS.register_module() +class Food101(BaseDataset): + """The Food101 Dataset. + + Support the `Food101 Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Food101 dataset directory: :: + + food-101 + ├── images + │ ├── class_x + │ │ ├── xx1.jpg + │ │ ├── xx2.jpg + │ │ └── ... + │ ├── class_y + │ │ ├── yy1.jpg + │ │ ├── yy2.jpg + │ │ └── ... + │ └── ... + ├── meta + │ ├── train.txt + │ └── test.txt + └── .... + + Args: + data_root (str): The root directory for Food101 dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". 
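+        **kwargs: Other keyword arguments in :class:`BaseDataset`.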
+ + Examples: + >>> from mmpretrain.datasets import Food101 + >>> train_dataset = Food101(data_root='data/food-101', split='train') + >>> train_dataset + Dataset Food101 + Number of samples: 75750 + Number of categories: 101 + Root of dataset: data/food-101 + >>> test_dataset = Food101(data_root='data/food-101', split='test') + >>> test_dataset + Dataset Food101 + Number of samples: 25250 + Number of categories: 101 + Root of dataset: data/food-101 + """ # noqa: E501 + + METAINFO = {'classes': FOOD101_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'train', **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + if split == 'train': + ann_file = self.backend.join_path('meta', 'train.txt') + else: + ann_file = self.backend.join_path('meta', 'test.txt') + + test_mode = split == 'test' + data_prefix = 'images' + + super(Food101, self).__init__( + ann_file=ann_file, + data_root=data_root, + test_mode=test_mode, + data_prefix=data_prefix, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + for pair in pairs: + class_name, img_name = pair.split('/') + img_name = f'{img_name}.jpg' + img_path = self.backend.join_path(self.img_prefix, class_name, + img_name) + gt_label = self.METAINFO['classes'].index(class_name) + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/gqa_dataset.py b/mmpretrain/datasets/gqa_dataset.py new file mode 100644 index 0000000..741791b --- /dev/null +++ b/mmpretrain/datasets/gqa_dataset.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class GQA(BaseDataset): + """GQA dataset. + + We use the annotation file from LAVIS, and you can download all annotation files from following links: # noqa: E501 + + train: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json # noqa: E501 + val: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json # noqa: E501 + test: + https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json # noqa: E501 + + and images from the official website: + https://cs.stanford.edu/people/dorarad/gqa/index.html + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
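+
+    Examples:
+        Assuming the LAVIS annotation files listed above are placed under an
+        illustrative ``data/gqa`` directory together with the official images::
+
+            >>> from mmpretrain.datasets import GQA
+            >>> dataset = GQA(
+            ...     data_root='data/gqa',
+            ...     data_prefix='images',
+            ...     ann_file='annotations/testdev_balanced_questions.json')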
+ """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + data_list = [] + for ann in annotations: + # ann example + # { + # 'question': "Is it overcast?", + # 'answer': 'no, + # 'image_id': n161313.jpg, + # 'question_id': 262148000, + # .... + # } + data_info = dict() + data_info['img_path'] = osp.join(self.data_prefix['img_path'], + ann['image']) + data_info['question'] = ann['question'] + data_info['gt_answer'] = ann['answer'] + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/iconqa.py b/mmpretrain/datasets/iconqa.py new file mode 100644 index 0000000..20c4d87 --- /dev/null +++ b/mmpretrain/datasets/iconqa.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import list_dir_or_file +from mmengine.utils import check_file_exist + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class IconQA(BaseDataset): + """IconQA: A benchmark for abstract diagram understanding + and visual language reasoning. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of the specific task and split. + eg. ``iconqa/val/choose_text/``. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, data_root: str, data_prefix: str, **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + sample_list = list( + list_dir_or_file(self.data_prefix['img_path'], list_file=False)) + + data_list = list() + for sample_id in sample_list: + # data json + # { + # "question": "How likely is it that you will pick a black one?", + # "choices": [ + # "certain", + # "unlikely", + # "impossible", + # "probable" + # ], + # "answer": 2, + # "ques_type": "choose_txt", + # "grade": "grade1", + # "label": "S2" + # } + data_info = mmengine.load( + mmengine.join_path(self.data_prefix['img_path'], sample_id, + 'data.json')) + data_info['gt_answer'] = data_info['choices'][int( + data_info['answer'])] + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], sample_id, 'image.png') + check_file_exist(data_info['img_path']) + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/imagenet.py b/mmpretrain/datasets/imagenet.py new file mode 100644 index 0000000..771d6ee --- /dev/null +++ b/mmpretrain/datasets/imagenet.py @@ -0,0 +1,235 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +from mmengine import fileio +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .categories import IMAGENET_CATEGORIES +from .custom import CustomDataset + + +@DATASETS.register_module() +class ImageNet(CustomDataset): + """`ImageNet `_ Dataset. + + The dataset supports two kinds of directory format, + + :: + + imagenet + ├── train + │ ├──class_x + | | ├── x1.jpg + | | ├── x2.jpg + | | └── ... + │ ├── class_y + | | ├── y1.jpg + | | ├── y2.jpg + | | └── ... + | └── ... + ├── val + │ ├──class_x + | | └── ... + │ ├── class_y + | | └── ... 
+ | └── ... + └── test + ├── test1.jpg + ├── test2.jpg + └── ... + + or :: + + imagenet + ├── train + │ ├── x1.jpg + │ ├── y1.jpg + │ └── ... + ├── val + │ ├── x3.jpg + │ ├── y3.jpg + │ └── ... + ├── test + │ ├── test1.jpg + │ ├── test2.jpg + │ └── ... + └── meta + ├── train.txt + └── val.txt + + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + split (str): The dataset split, supports "train", "val" and "test". + Default to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + + + Examples: + >>> from mmpretrain.datasets import ImageNet + >>> train_dataset = ImageNet(data_root='data/imagenet', split='train') + >>> train_dataset + Dataset ImageNet + Number of samples: 1281167 + Number of categories: 1000 + Root of dataset: data/imagenet + >>> test_dataset = ImageNet(data_root='data/imagenet', split='val') + >>> test_dataset + Dataset ImageNet + Number of samples: 50000 + Number of categories: 1000 + Root of dataset: data/imagenet + """ # noqa: E501 + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + METAINFO = {'classes': IMAGENET_CATEGORIES} + + def __init__(self, + data_root: str = '', + split: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + **kwargs): + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + + if split: + splits = ['train', 'val', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + + if split == 'test': + logger = MMLogger.get_current_instance() + logger.info( + 'Since the ImageNet1k test set does not provide label' + 'annotations, `with_label` is set to False') + kwargs['with_label'] = False + + data_prefix = split if data_prefix == '' else data_prefix + + if ann_file == '': + _ann_path = fileio.join_path(data_root, 'meta', f'{split}.txt') + if fileio.exists(_ann_path): + ann_file = fileio.join_path('meta', f'{split}.txt') + + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body + + +@DATASETS.register_module() +class ImageNet21k(CustomDataset): + """ImageNet21k Dataset. + + Since the dataset ImageNet21k is extremely big, contains 21k+ classes + and 1.4B files. We won't provide the default categories list. Please + specify it from the ``classes`` argument. + The dataset directory structure is as follows, + + ImageNet21k dataset directory :: + + imagenet21k + ├── train + │ ├──class_x + | | ├── x1.jpg + | | ├── x2.jpg + | | └── ... + │ ├── class_y + | | ├── y1.jpg + | | ├── y2.jpg + | | └── ... + | └── ... + └── meta + └── train.txt + + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + multi_label (bool): Not implement by now. Use multi label or not. + Defaults to False. 
+ **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import ImageNet21k + >>> train_dataset = ImageNet21k(data_root='data/imagenet21k', split='train') + >>> train_dataset + Dataset ImageNet21k + Number of samples: 14197088 + Annotation file: data/imagenet21k/meta/train.txt + Prefix of images: data/imagenet21k/train + """ # noqa: E501 + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + + def __init__(self, + data_root: str = '', + split: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + multi_label: bool = False, + **kwargs): + if multi_label: + raise NotImplementedError( + 'The `multi_label` option is not supported by now.') + self.multi_label = multi_label + + if split: + splits = ['train'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'.\ + If you want to specify your own validation set or test set,\ + please set split to None." + + self.split = split + data_prefix = split if data_prefix == '' else data_prefix + + if not ann_file: + _ann_path = fileio.join_path(data_root, 'meta', f'{split}.txt') + if fileio.exists(_ann_path): + ann_file = fileio.join_path('meta', f'{split}.txt') + + logger = MMLogger.get_current_instance() + + if not ann_file: + logger.warning( + 'The ImageNet21k dataset is large, and scanning directory may ' + 'consume long time. Considering to specify the `ann_file` to ' + 'accelerate the initialization.') + + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) + + if self.CLASSES is None: + logger.warning( + 'The CLASSES is not stored in the `ImageNet21k` class. ' + 'Considering to specify the `classes` argument if you need ' + 'do inference on the ImageNet-21k dataset') diff --git a/mmpretrain/datasets/infographic_vqa.py b/mmpretrain/datasets/infographic_vqa.py new file mode 100644 index 0000000..46f5b0a --- /dev/null +++ b/mmpretrain/datasets/infographic_vqa.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class InfographicVQA(BaseDataset): + """Infographic VQA dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
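+
+ The snippet below is only a usage sketch; the root directory and the
+ annotation file name are illustrative assumptions.
+
+ Examples:
+ >>> from mmpretrain.datasets import InfographicVQA
+ >>> val_dataset = InfographicVQA(
+ ... data_root='data/infographicvqa',
+ ... data_prefix='images',
+ ... ann_file='infographicsVQA_val_v1.0.json')
+ >>> # every sample provides 'img_path', 'question' and, when the split
+ >>> # is annotated, a list of reference answers in 'gt_answer'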
+ """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + annotations = annotations['data'] + + data_list = [] + for ann in annotations: + # ann example + # { + # "questionId": 98313, + # "question": "Which social platform has heavy female audience?", + # "image_local_name": "37313.jpeg", + # "image_url": "https://xxx.png", + # "ocr_output_file": "37313.json", + # "answers": [ + # "pinterest" + # ], + # "data_split": "val" + # } + data_info = dict() + data_info['question'] = ann['question'] + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], ann['image_local_name']) + if 'answers' in ann.keys(): # test splits do not include gt + data_info['gt_answer'] = ann['answers'] + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/inshop.py b/mmpretrain/datasets/inshop.py new file mode 100644 index 0000000..f64f177 --- /dev/null +++ b/mmpretrain/datasets/inshop.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class InShop(BaseDataset): + """InShop Dataset for Image Retrieval. + + Please download the images from the homepage + 'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html' + (In-shop Clothes Retrieval Benchmark -> Img -> img.zip, + Eval/list_eval_partition.txt), and organize them as follows way: :: + + In-shop Clothes Retrieval Benchmark (data_root)/ + ├── Eval / + │ └── list_eval_partition.txt (ann_file) + ├── Img (img_prefix) + │ └── img/ + ├── README.txt + └── ..... + + Args: + data_root (str): The root directory for dataset. + split (str): Choose from 'train', 'query' and 'gallery'. + Defaults to 'train'. + data_prefix (str | dict): Prefix for training data. + Defaults to 'Img'. + ann_file (str): Annotation file path, path relative to + ``data_root``. Defaults to 'Eval/list_eval_partition.txt'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import InShop + >>> + >>> # build train InShop dataset + >>> inshop_train_cfg = dict(data_root='data/inshop', split='train') + >>> inshop_train = InShop(**inshop_train_cfg) + >>> inshop_train + Dataset InShop + Number of samples: 25882 + The `CLASSES` meta info is not set. + Root of dataset: data/inshop + >>> + >>> # build query InShop dataset + >>> inshop_query_cfg = dict(data_root='data/inshop', split='query') + >>> inshop_query = InShop(**inshop_query_cfg) + >>> inshop_query + Dataset InShop + Number of samples: 14218 + The `CLASSES` meta info is not set. + Root of dataset: data/inshop + >>> + >>> # build gallery InShop dataset + >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery') + >>> inshop_gallery = InShop(**inshop_gallery_cfg) + >>> inshop_gallery + Dataset InShop + Number of samples: 12612 + The `CLASSES` meta info is not set. 
+ Root of dataset: data/inshop + """ + + def __init__(self, + data_root: str, + split: str = 'train', + data_prefix: str = 'Img', + ann_file: str = 'Eval/list_eval_partition.txt', + **kwargs): + + assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \ + f" must be one of ['train', 'query', 'gallery'], bu get '{split}'" + self.backend = get_file_backend(data_root, enable_singleton=True) + self.split = split + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + **kwargs) + + def _process_annotations(self): + lines = list_from_file(self.ann_file) + + anno_train = dict(metainfo=dict(), data_list=list()) + anno_gallery = dict(metainfo=dict(), data_list=list()) + + # item_id to label, each item corresponds to one class label + class_num = 0 + gt_label_train = {} + + # item_id to label, each label corresponds to several items + gallery_num = 0 + gt_label_gallery = {} + + # (lines[0], lines[1]) is the image number and the field name; + # Each line format as 'image_name, item_id, evaluation_status' + for line in lines[2:]: + img_name, item_id, status = line.split() + img_path = self.backend.join_path(self.img_prefix, img_name) + if status == 'train': + if item_id not in gt_label_train: + gt_label_train[item_id] = class_num + class_num += 1 + # item_id to class_id (for the training set) + anno_train['data_list'].append( + dict(img_path=img_path, gt_label=gt_label_train[item_id])) + elif status == 'gallery': + if item_id not in gt_label_gallery: + gt_label_gallery[item_id] = [] + # Since there are multiple images for each item, + # record the corresponding item for each image. + gt_label_gallery[item_id].append(gallery_num) + anno_gallery['data_list'].append( + dict(img_path=img_path, sample_idx=gallery_num)) + gallery_num += 1 + + if self.split == 'train': + anno_train['metainfo']['class_number'] = class_num + anno_train['metainfo']['sample_number'] = \ + len(anno_train['data_list']) + return anno_train + elif self.split == 'gallery': + anno_gallery['metainfo']['sample_number'] = gallery_num + return anno_gallery + + # Generate the label for the query(val) set + anno_query = dict(metainfo=dict(), data_list=list()) + query_num = 0 + for line in lines[2:]: + img_name, item_id, status = line.split() + img_path = self.backend.join_path(self.img_prefix, img_name) + if status == 'query': + anno_query['data_list'].append( + dict( + img_path=img_path, gt_label=gt_label_gallery[item_id])) + query_num += 1 + + anno_query['metainfo']['sample_number'] = query_num + return anno_query + + def load_data_list(self): + """load data list. + + For the train set, return image and ground truth label. For the query + set, return image and ids of images in gallery. For the gallery set, + return image and its id. + """ + data_info = self._process_annotations() + data_list = data_info['data_list'] + return data_list + + def extra_repr(self): + """The extra repr information of the dataset.""" + body = [f'Root of dataset: \t{self.data_root}'] + return body diff --git a/mmpretrain/datasets/minigpt4_dataset.py b/mmpretrain/datasets/minigpt4_dataset.py new file mode 100644 index 0000000..e14e5c3 --- /dev/null +++ b/mmpretrain/datasets/minigpt4_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class MiniGPT4Dataset(BaseDataset): + """Dataset for training MiniGPT4. + + MiniGPT4 dataset directory: + + minigpt4_dataset + ├── image + │ ├── id0.jpg + │ │── id1.jpg + │ │── id2.jpg + │ └── ... + └── conversation_data.json + + The structure of conversation_data.json: + + [ + // English data + { + "id": str(id0), + "conversation": "###Ask: [Ask content] + ###Answer: [Answer content]" + }, + + // Chinese data + { + "id": str(id1), + "conversation": "###问: [Ask content] + ###答:[Answer content]" + }, + + ... + ] + + Args: + data_root (str): The root directory for ``ann_file`` and ``image``. + ann_file (str): Conversation file path. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def load_data_list(self) -> List[dict]: + file_backend = get_file_backend(self.data_root) + conversation_path = file_backend.join_path(self.data_root, + self.ann_file) + conversation = mmengine.load(conversation_path) + img_ids = {} + n = 0 + for conv in conversation: + img_id = conv['id'] + if img_id not in img_ids.keys(): + img_ids[img_id] = n + n += 1 + + img_root = file_backend.join_path(self.data_root, 'image') + data_list = [] + for conv in conversation: + img_file = '{}.jpg'.format(conv['id']) + chat_content = conv['conversation'] + lang = 'en' if chat_content.startswith('###Ask: ') else 'zh' + data_info = { + 'image_id': img_ids[conv['id']], + 'img_path': file_backend.join_path(img_root, img_file), + 'chat_content': chat_content, + 'lang': lang, + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/mnist.py b/mmpretrain/datasets/mnist.py new file mode 100644 index 0000000..425267f --- /dev/null +++ b/mmpretrain/datasets/mnist.py @@ -0,0 +1,234 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import codecs +from typing import List, Optional +from urllib.parse import urljoin + +import mmengine.dist as dist +import numpy as np +import torch +from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES +from .utils import (download_and_extract_archive, open_maybe_compressed_file, + rm_suffix) + + +@DATASETS.register_module() +class MNIST(BaseDataset): + """`MNIST `_ Dataset. + + This implementation is modified from + https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py + + Args: + data_root (str): The root directory of the MNIST Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
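+
+ The snippet below is only a usage sketch; ``data/mnist`` is an illustrative
+ directory, and with ``download=True`` missing files are fetched from
+ ``url_prefix`` automatically.
+
+ Examples:
+ >>> from mmpretrain.datasets import MNIST
+ >>> train_dataset = MNIST(data_root='data/mnist', split='train')
+ >>> test_dataset = MNIST(data_root='data/mnist', split='test', test_mode=True)
+ >>> # every sample carries the decoded 'img' array and its 'gt_label'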
+ """ # noqa: E501 + + url_prefix = 'http://yann.lecun.com/exdb/mnist/' + # train images and labels + train_list = [ + ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'], + ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'], + ] + # test images and labels + test_list = [ + ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'], + ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'], + ] + METAINFO = {'classes': MNIST_CATEGORITES} + + def __init__(self, + data_root: str = '', + split: str = 'train', + metainfo: Optional[dict] = None, + download: bool = True, + data_prefix: str = '', + test_mode: bool = False, + **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + # To handle the BC-breaking + if split == 'train' and test_mode: + logger = MMLogger.get_current_instance() + logger.warning('split="train" but test_mode=True. ' + 'The training set will be used.') + + if not data_root and not data_prefix: + raise RuntimeError('Please set ``data_root`` to' + 'specify the dataset path') + + self.download = download + super().__init__( + # The MNIST dataset doesn't need specify annotation file + ann_file='', + metainfo=metainfo, + data_root=data_root, + data_prefix=dict(root=data_prefix), + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + root = self.data_prefix['root'] + backend = get_file_backend(root, enable_singleton=True) + + if dist.is_main_process() and not self._check_exists(): + if not isinstance(backend, LocalBackend): + raise RuntimeError(f'The dataset on {root} is not integrated, ' + f'please manually handle it.') + + if self.download: + self._download() + else: + raise RuntimeError( + f'Cannot find {self.__class__.__name__} dataset in ' + f"{self.data_prefix['root']}, you can specify " + '`download=True` to download automatically.') + + dist.barrier() + assert self._check_exists(), \ + 'Download failed or shared storage is unavailable. Please ' \ + f'download the dataset manually through {self.url_prefix}.' + + if not self.test_mode: + file_list = self.train_list + else: + file_list = self.test_list + + # load data from SN3 files + imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0]))) + gt_labels = read_label_file( + join_path(root, rm_suffix(file_list[1][0]))) + + data_infos = [] + for img, gt_label in zip(imgs, gt_labels): + gt_label = np.array(gt_label, dtype=np.int64) + info = {'img': img.numpy(), 'gt_label': gt_label} + data_infos.append(info) + return data_infos + + def _check_exists(self): + """Check the exists of data files.""" + root = self.data_prefix['root'] + + for filename, _ in (self.train_list + self.test_list): + # get extracted filename of data + extract_filename = rm_suffix(filename) + fpath = join_path(root, extract_filename) + if not exists(fpath): + return False + return True + + def _download(self): + """Download and extract data files.""" + root = self.data_prefix['root'] + + for filename, md5 in (self.train_list + self.test_list): + url = urljoin(self.url_prefix, filename) + download_and_extract_archive( + url, download_root=root, filename=filename, md5=md5) + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [f"Prefix of data: \t{self.data_prefix['root']}"] + return body + + +@DATASETS.register_module() +class FashionMNIST(MNIST): + """`Fashion-MNIST `_ + Dataset. 
+ + Args: + data_root (str): The root directory of the MNIST Dataset. + split (str, optional): The dataset split, supports "train" and "test". + Default to "train". + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + download (bool): Whether to download the dataset if not exists. + Defaults to True. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' + # train images and labels + train_list = [ + ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'], + ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'], + ] + # test images and labels + test_list = [ + ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'], + ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'], + ] + METAINFO = {'classes': FASHIONMNIST_CATEGORITES} + + +def get_int(b: bytes) -> int: + """Convert bytes to int.""" + return int(codecs.encode(b, 'hex'), 16) + + +def read_sn3_pascalvincent_tensor(path: str, + strict: bool = True) -> torch.Tensor: + """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx- + io.lsh'). + + Argument may be a filename, compressed filename, or file object. + """ + # typemap + if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'): + read_sn3_pascalvincent_tensor.typemap = { + 8: (torch.uint8, np.uint8, np.uint8), + 9: (torch.int8, np.int8, np.int8), + 11: (torch.int16, np.dtype('>i2'), 'i2'), + 12: (torch.int32, np.dtype('>i4'), 'i4'), + 13: (torch.float32, np.dtype('>f4'), 'f4'), + 14: (torch.float64, np.dtype('>f8'), 'f8') + } + # read + with open_maybe_compressed_file(path) as f: + data = f.read() + # parse + magic = get_int(data[0:4]) + nd = magic % 256 + ty = magic // 256 + assert nd >= 1 and nd <= 3 + assert ty >= 8 and ty <= 14 + m = read_sn3_pascalvincent_tensor.typemap[ty] + s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)] + parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) + assert parsed.shape[0] == np.prod(s) or not strict + return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) + + +def read_label_file(path: str) -> torch.Tensor: + """Read labels from SN3 label file.""" + with open(path, 'rb') as f: + x = read_sn3_pascalvincent_tensor(f, strict=False) + assert (x.dtype == torch.uint8) + assert (x.ndimension() == 1) + return x.long() + + +def read_image_file(path: str) -> torch.Tensor: + """Read images from SN3 image file.""" + with open(path, 'rb') as f: + x = read_sn3_pascalvincent_tensor(f, strict=False) + assert (x.dtype == torch.uint8) + assert (x.ndimension() == 3) + return x diff --git a/mmpretrain/datasets/multi_label.py b/mmpretrain/datasets/multi_label.py new file mode 100644 index 0000000..58a9c7c --- /dev/null +++ b/mmpretrain/datasets/multi_label.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class MultiLabelDataset(BaseDataset): + """Multi-label Dataset. + + This dataset support annotation file in `OpenMMLab 2.0 style annotation + format`. + + The annotation format is shown as follows. + + .. code-block:: none + + { + "metainfo": + { + "classes":['A', 'B', 'C'....] + }, + "data_list": + [ + { + "img_path": "test_img1.jpg", + 'gt_label': [0, 1], + }, + { + "img_path": "test_img2.jpg", + 'gt_label': [2], + }, + ] + .... 
+ } + + + Args: + ann_file (str): Annotation file path. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. + classes (str | Sequence[str], optional): Specify names of classes. + + - If is string, it should be a file path, and the every line of + the file is a name of a class. + - If is a sequence of string, every item is a name of class. + - If is None, use categories information in ``metainfo`` argument, + annotation file or the class attribute ``METAINFO``. + + Defaults to None. + """ + + def get_cat_ids(self, idx: int) -> List[int]: + """Get category ids by index. + + Args: + idx (int): Index of data. + + Returns: + cat_ids (List[int]): Image categories of specified index. + """ + return self.get_data_info(idx)['gt_label'] diff --git a/mmpretrain/datasets/multi_task.py b/mmpretrain/datasets/multi_task.py new file mode 100644 index 0000000..443df0e --- /dev/null +++ b/mmpretrain/datasets/multi_task.py @@ -0,0 +1,337 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from os import PathLike +from typing import Optional, Sequence + +import mmengine +from mmcv.transforms import Compose +from mmengine.fileio import get_file_backend + +from .builder import DATASETS + + +def expanduser(path): + if isinstance(path, (str, PathLike)): + return osp.expanduser(path) + else: + return path + + +def isabs(uri): + return osp.isabs(uri) or ('://' in uri) + + +@DATASETS.register_module() +class MultiTaskDataset: + """Custom dataset for multi-task dataset. + + To use the dataset, please generate and provide an annotation file in the + below format: + + .. code-block:: json + + { + "metainfo": { + "tasks": + [ + 'gender' + 'wear' + ] + }, + "data_list": [ + { + "img_path": "a.jpg", + gt_label:{ + "gender": 0, + "wear": [1, 0, 1, 0] + } + }, + { + "img_path": "b.jpg", + gt_label:{ + "gender": 1, + "wear": [1, 0, 1, 0] + } + } + ] + } + + Assume we put our dataset in the ``data/mydataset`` folder in the + repository and organize it as the below format: :: + + mmpretrain/ + └── data + └── mydataset + ├── annotation + │   ├── train.json + │   ├── test.json + │   └── val.json + ├── train + │   ├── a.jpg + │   └── ... 
+ ├── test + │   ├── b.jpg + │   └── ... + └── val + ├── c.jpg + └── ... + + We can use the below config to build datasets: + + .. code:: python + + >>> from mmpretrain.datasets import build_dataset + >>> train_cfg = dict( + ... type="MultiTaskDataset", + ... ann_file="annotation/train.json", + ... data_root="data/mydataset", + ... # The `img_path` field in the train annotation file is relative + ... # to the `train` folder. + ... data_prefix='train', + ... ) + >>> train_dataset = build_dataset(train_cfg) + + Or we can put all files in the same folder: :: + + mmpretrain/ + └── data + └── mydataset + ├── train.json + ├── test.json + ├── val.json + ├── a.jpg + ├── b.jpg + ├── c.jpg + └── ... + + And we can use the below config to build datasets: + + .. code:: python + + >>> from mmpretrain.datasets import build_dataset + >>> train_cfg = dict( + ... type="MultiTaskDataset", + ... ann_file="train.json", + ... data_root="data/mydataset", + ... # the `data_prefix` is not required since all paths are + ... # relative to the `data_root`. + ... ) + >>> train_dataset = build_dataset(train_cfg) + + + Args: + ann_file (str): The annotation file path. It can be either absolute + path or relative path to the ``data_root``. + metainfo (dict, optional): The extra meta information. It should be + a dict with the same format as the ``"metainfo"`` field in the + annotation file. Defaults to None. + data_root (str, optional): The root path of the data directory. It's + the prefix of the ``data_prefix`` and the ``ann_file``. And it can + be a remote path like "s3://openmmlab/xxx/". Defaults to None. + data_prefix (str, optional): The base folder relative to the + ``data_root`` for the ``"img_path"`` field in the annotation file. + Defaults to None. + pipeline (Sequence[dict]): A list of dict, where each element + represents a operation defined in + :mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple. + test_mode (bool): in train mode or test mode. Defaults to False. + """ + METAINFO = dict() + + def __init__(self, + ann_file: str, + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: Optional[str] = None, + pipeline: Sequence = (), + test_mode: bool = False): + + self.data_root = expanduser(data_root) + + # Inference the file client + if self.data_root is not None: + self.file_backend = get_file_backend(uri=self.data_root) + else: + self.file_backend = None + + self.ann_file = self._join_root(expanduser(ann_file)) + self.data_prefix = self._join_root(data_prefix) + + self.test_mode = test_mode + self.pipeline = Compose(pipeline) + self.data_list = self.load_data_list(self.ann_file, metainfo) + + def _join_root(self, path): + """Join ``self.data_root`` with the specified path. + + If the path is an absolute path, just return the path. And if the + path is None, return ``self.data_root``. + + Examples: + >>> self.data_root = 'a/b/c' + >>> self._join_root('d/e/') + 'a/b/c/d/e' + >>> self._join_root('https://openmmlab.com') + 'https://openmmlab.com' + >>> self._join_root(None) + 'a/b/c' + """ + if path is None: + return self.data_root + if isabs(path): + return path + + joined_path = self.file_backend.join_path(self.data_root, path) + return joined_path + + @classmethod + def _get_meta_info(cls, in_metainfo: dict = None) -> dict: + """Collect meta information from the dictionary of meta. + + Args: + in_metainfo (dict): Meta information dict. + + Returns: + dict: Parsed meta information. 
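+
+ Example (illustrative, with the default empty ``METAINFO``)::
+
+ >>> MultiTaskDataset._get_meta_info({'tasks': ['gender', 'wear']})
+ {'tasks': ['gender', 'wear']}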
+ """ + # `cls.METAINFO` will be overwritten by in_meta + metainfo = copy.deepcopy(cls.METAINFO) + if in_metainfo is None: + return metainfo + + metainfo.update(in_metainfo) + + return metainfo + + def load_data_list(self, ann_file, metainfo_override=None): + """Load annotations from an annotation file. + + Args: + ann_file (str): Absolute annotation file path if ``self.root=None`` + or relative path if ``self.root=/path/to/data/``. + + Returns: + list[dict]: A list of annotation. + """ + annotations = mmengine.load(ann_file) + if not isinstance(annotations, dict): + raise TypeError(f'The annotations loaded from annotation file ' + f'should be a dict, but got {type(annotations)}!') + if 'data_list' not in annotations: + raise ValueError('The annotation file must have the `data_list` ' + 'field.') + metainfo = annotations.get('metainfo', {}) + raw_data_list = annotations['data_list'] + + # Set meta information. + assert isinstance(metainfo, dict), 'The `metainfo` field in the '\ + f'annotation file should be a dict, but got {type(metainfo)}' + if metainfo_override is not None: + assert isinstance(metainfo_override, dict), 'The `metainfo` ' \ + f'argument should be a dict, but got {type(metainfo_override)}' + metainfo.update(metainfo_override) + self._metainfo = self._get_meta_info(metainfo) + + data_list = [] + for i, raw_data in enumerate(raw_data_list): + try: + data_list.append(self.parse_data_info(raw_data)) + except AssertionError as e: + raise RuntimeError( + f'The format check fails during parse the item {i} of ' + f'the annotation file with error: {e}') + return data_list + + def parse_data_info(self, raw_data): + """Parse raw annotation to target format. + + This method will return a dict which contains the data information of a + sample. + + Args: + raw_data (dict): Raw data information load from ``ann_file`` + + Returns: + dict: Parsed annotation. + """ + assert isinstance(raw_data, dict), \ + f'The item should be a dict, but got {type(raw_data)}' + assert 'img_path' in raw_data, \ + "The item doesn't have `img_path` field." + data = dict( + img_path=self._join_root(raw_data['img_path']), + gt_label=raw_data['gt_label'], + ) + return data + + @property + def metainfo(self) -> dict: + """Get meta information of dataset. + + Returns: + dict: meta information collected from ``cls.METAINFO``, + annotation file and metainfo argument during instantiation. + """ + return copy.deepcopy(self._metainfo) + + def prepare_data(self, idx): + """Get data processed by ``self.pipeline``. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + results = copy.deepcopy(self.data_list[idx]) + return self.pipeline(results) + + def __len__(self): + """Get the length of the whole dataset. + + Returns: + int: The length of filtered dataset. + """ + return len(self.data_list) + + def __getitem__(self, idx): + """Get the idx-th image and data information of dataset after + ``self.pipeline``. + + Args: + idx (int): The index of of the data. + + Returns: + dict: The idx-th image and data information after + ``self.pipeline``. + """ + return self.prepare_data(idx) + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. 
+ """ + head = 'Dataset ' + self.__class__.__name__ + body = [f'Number of samples: \t{self.__len__()}'] + if self.data_root is not None: + body.append(f'Root location: \t{self.data_root}') + body.append(f'Annotation file: \t{self.ann_file}') + if self.data_prefix is not None: + body.append(f'Prefix of images: \t{self.data_prefix}') + # -------------------- extra repr -------------------- + tasks = self.metainfo['tasks'] + body.append(f'For {len(tasks)} tasks') + for task in tasks: + body.append(f' {task} ') + # ---------------------------------------------------- + + if len(self.pipeline.transforms) > 0: + body.append('With transforms:') + for t in self.pipeline.transforms: + body.append(f' {t}') + + lines = [head] + [' ' * 4 + line for line in body] + return '\n'.join(lines) diff --git a/mmpretrain/datasets/nlvr2.py b/mmpretrain/datasets/nlvr2.py new file mode 100644 index 0000000..0063090 --- /dev/null +++ b/mmpretrain/datasets/nlvr2.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from typing import List + +from mmengine.fileio import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class NLVR2(BaseDataset): + """COCO Caption dataset.""" + + def load_data_list(self) -> List[dict]: + """Load data list.""" + + data_list = [] + img_prefix = self.data_prefix['img_path'] + file_backend = get_file_backend(img_prefix) + examples = list_from_file(self.ann_file) + + for example in examples: + example = json.loads(example) + prefix = example['identifier'].rsplit('-', 1)[0] + train_data = {} + train_data['text'] = example['sentence'] + train_data['gt_label'] = {'True': 1, 'False': 0}[example['label']] + train_data['img_path'] = [ + file_backend.join_path(img_prefix, prefix + f'-img{i}.png') + for i in range(2) + ] + + data_list.append(train_data) + + return data_list diff --git a/mmpretrain/datasets/nocaps.py b/mmpretrain/datasets/nocaps.py new file mode 100644 index 0000000..65116e9 --- /dev/null +++ b/mmpretrain/datasets/nocaps.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend +from pycocotools.coco import COCO + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class NoCaps(BaseDataset): + """NoCaps dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``.. + ann_file (str): Annotation file path. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
+ """ + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + + file_backend = get_file_backend(img_prefix) + data_list = [] + for ann in coco.anns.values(): + image_id = ann['image_id'] + image_path = file_backend.join_path( + img_prefix, coco.imgs[image_id]['file_name']) + data_info = { + 'image_id': image_id, + 'img_path': image_path, + 'gt_caption': None + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/ocr_vqa.py b/mmpretrain/datasets/ocr_vqa.py new file mode 100644 index 0000000..55aa691 --- /dev/null +++ b/mmpretrain/datasets/ocr_vqa.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class OCRVQA(BaseDataset): + """OCR-VQA dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str): Annotation file path for training and validation. + split (str): 'train', 'val' or 'test'. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, data_root: str, data_prefix: str, ann_file: str, + split: str, **kwarg): + + assert split in ['train', 'val', 'test'], \ + '`split` must be train, val or test' + self.split = split + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + + split_dict = {1: 'train', 2: 'val', 3: 'test'} + + annotations = mmengine.load(self.ann_file) + + # ann example + # "761183272": { + # "imageURL": \ + # "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg", + # "questions": [ + # "Who wrote this book?", + # "What is the title of this book?", + # "What is the genre of this book?", + # "Is this a games related book?", + # "What is the year printed on this calendar?"], + # "answers": [ + # "Sandra Boynton", + # "Mom's Family Wall Calendar 2016", + # "Calendars", + # "No", + # "2016"], + # "title": "Mom's Family Wall Calendar 2016", + # "authorName": "Sandra Boynton", + # "genre": "Calendars", + # "split": 1 + # }, + + data_list = [] + + for key, ann in annotations.items(): + if self.split != split_dict[ann['split']]: + continue + + extension = osp.splitext(ann['imageURL'])[1] + if extension not in ['.jpg', '.png']: + continue + img_path = mmengine.join_path(self.data_prefix['img_path'], + key + extension) + for question, answer in zip(ann['questions'], ann['answers']): + data_info = {} + data_info['img_path'] = img_path + data_info['question'] = question + data_info['gt_answer'] = answer + data_info['gt_answer_weight'] = [1.0] + + data_info['imageURL'] = ann['imageURL'] + data_info['title'] = ann['title'] + data_info['authorName'] = ann['authorName'] + data_info['genre'] = ann['genre'] + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/oxfordiiitpet.py b/mmpretrain/datasets/oxfordiiitpet.py new file mode 100644 index 0000000..23c8b7d --- /dev/null +++ b/mmpretrain/datasets/oxfordiiitpet.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import OxfordIIITPet_CATEGORIES + + +@DATASETS.register_module() +class OxfordIIITPet(BaseDataset): + """The Oxford-IIIT Pets Dataset. + + Support the `Oxford-IIIT Pets Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. + + Oxford-IIIT_Pets dataset directory: :: + + Oxford-IIIT_Pets + ├── images + │ ├── Abyssinian_1.jpg + │ ├── Abyssinian_2.jpg + │ └── ... + ├── annotations + │ ├── trainval.txt + │ ├── test.txt + │ ├── list.txt + │ └── ... + └── .... + + Args: + data_root (str): The root directory for Oxford-IIIT Pets dataset. + split (str, optional): The dataset split, supports "trainval" and "test". + Default to "trainval". + + Examples: + >>> from mmpretrain.datasets import OxfordIIITPet + >>> train_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='trainval') + >>> train_dataset + Dataset OxfordIIITPet + Number of samples: 3680 + Number of categories: 37 + Root of dataset: data/Oxford-IIIT_Pets + >>> test_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='test') + >>> test_dataset + Dataset OxfordIIITPet + Number of samples: 3669 + Number of categories: 37 + Root of dataset: data/Oxford-IIIT_Pets + """ # noqa: E501 + + METAINFO = {'classes': OxfordIIITPet_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'trainval', **kwargs): + + splits = ['trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + self.backend = get_file_backend(data_root, enable_singleton=True) + if split == 'trainval': + ann_file = self.backend.join_path('annotations', 'trainval.txt') + else: + ann_file = self.backend.join_path('annotations', 'test.txt') + + data_prefix = 'images' + test_mode = split == 'test' + + super(OxfordIIITPet, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + """Load images and ground truth labels.""" + + pairs = list_from_file(self.ann_file) + data_list = [] + for pair in pairs: + img_name, class_id, _, _ = pair.split() + img_name = f'{img_name}.jpg' + img_path = self.backend.join_path(self.img_prefix, img_name) + gt_label = int(class_id) - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/places205.py b/mmpretrain/datasets/places205.py new file mode 100644 index 0000000..f3ba1ff --- /dev/null +++ b/mmpretrain/datasets/places205.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +from mmpretrain.registry import DATASETS +from .categories import PLACES205_CATEGORIES +from .custom import CustomDataset + + +@DATASETS.register_module() +class Places205(CustomDataset): + """`Places205 `_ Dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults + to ''. + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. 
+ **kwargs: Other keyword arguments in :class:`CustomDataset` and + :class:`BaseDataset`. + """ + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + METAINFO = {'classes': PLACES205_CATEGORIES} + + def __init__(self, + data_root: str = '', + data_prefix: Union[str, dict] = '', + ann_file: str = '', + metainfo: Optional[dict] = None, + **kwargs): + kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs} + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + ann_file=ann_file, + metainfo=metainfo, + **kwargs) diff --git a/mmpretrain/datasets/refcoco.py b/mmpretrain/datasets/refcoco.py new file mode 100644 index 0000000..39c3d3e --- /dev/null +++ b/mmpretrain/datasets/refcoco.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List + +import mmengine +import numpy as np +from mmengine.dataset import BaseDataset +from pycocotools.coco import COCO + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class RefCOCO(BaseDataset): + """RefCOCO dataset. + + RefCOCO is a popular dataset used for the task of visual grounding. + Here are the steps for accessing and utilizing the + RefCOCO dataset. + + You can access the RefCOCO dataset from the official source: + https://github.com/lichengunc/refer + + The RefCOCO dataset is organized in a structured format: :: + + FeaturesDict({ + 'coco_annotations': Sequence({ + 'area': int64, + 'bbox': BBoxFeature(shape=(4,), dtype=float32), + 'id': int64, + 'label': int64, + }), + 'image': Image(shape=(None, None, 3), dtype=uint8), + 'image/id': int64, + 'objects': Sequence({ + 'area': int64, + 'bbox': BBoxFeature(shape=(4,), dtype=float32), + 'gt_box_index': int64, + 'id': int64, + 'label': int64, + 'refexp': Sequence({ + 'raw': Text(shape=(), dtype=string), + 'refexp_id': int64, + }), + }), + }) + + Args: + ann_file (str): Annotation file path. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str): Prefix for training data. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
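+
+ The snippet below is only a usage sketch; the file names follow the
+ official "refer" release but should be treated as illustrative.
+
+ Examples:
+ >>> from mmpretrain.datasets import RefCOCO
+ >>> val_dataset = RefCOCO(
+ ... data_root='data/refcoco',
+ ... data_prefix='train2014',
+ ... ann_file='instances.json',
+ ... split_file='refs(unc).p',
+ ... split='val')
+ >>> # one sample per referring expression, each with 'img_path',
+ >>> # 'text' and 'gt_bboxes' (converted from XYWH to XYXY)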
+ """ + + def __init__(self, + data_root, + ann_file, + data_prefix, + split_file, + split='train', + **kwargs): + self.split_file = split_file + self.split = split + + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwargs, + ) + + def _join_prefix(self): + if not mmengine.is_abs(self.split_file) and self.split_file: + self.split_file = osp.join(self.data_root, self.split_file) + + return super()._join_prefix() + + def load_data_list(self) -> List[dict]: + """Load data list.""" + with mmengine.get_local_path(self.ann_file) as ann_file: + coco = COCO(ann_file) + splits = mmengine.load(self.split_file, file_format='pkl') + img_prefix = self.data_prefix['img_path'] + + data_list = [] + join_path = mmengine.fileio.get_file_backend(img_prefix).join_path + for refer in splits: + if refer['split'] != self.split: + continue + + ann = coco.anns[refer['ann_id']] + img = coco.imgs[ann['image_id']] + sentences = refer['sentences'] + bbox = np.array(ann['bbox'], dtype=np.float32) + bbox[2:4] = bbox[0:2] + bbox[2:4] # XYWH -> XYXY + + for sent in sentences: + data_info = { + 'img_path': join_path(img_prefix, img['file_name']), + 'image_id': ann['image_id'], + 'ann_id': ann['id'], + 'text': sent['sent'], + 'gt_bboxes': bbox[None, :], + } + data_list.append(data_info) + + if len(data_list) == 0: + raise ValueError(f'No sample in split "{self.split}".') + + return data_list diff --git a/mmpretrain/datasets/samplers/__init__.py b/mmpretrain/datasets/samplers/__init__.py new file mode 100644 index 0000000..2bccf9c --- /dev/null +++ b/mmpretrain/datasets/samplers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .repeat_aug import RepeatAugSampler +from .sequential import SequentialSampler + +__all__ = ['RepeatAugSampler', 'SequentialSampler'] diff --git a/mmpretrain/datasets/samplers/repeat_aug.py b/mmpretrain/datasets/samplers/repeat_aug.py new file mode 100644 index 0000000..d833a19 --- /dev/null +++ b/mmpretrain/datasets/samplers/repeat_aug.py @@ -0,0 +1,101 @@ +import math +from typing import Iterator, Optional, Sized + +import torch +from mmengine.dist import get_dist_info, is_main_process, sync_random_seed +from torch.utils.data import Sampler + +from mmpretrain.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class RepeatAugSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset for + distributed, with repeated augmentation. It ensures that different each + augmented version of a sample will be visible to a different process (GPU). + Heavily based on torch.utils.data.DistributedSampler. + + This sampler was taken from + https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py + Used in + Copyright (c) 2015-present, Facebook, Inc. + + Args: + dataset (Sized): The dataset. + shuffle (bool): Whether shuffle the dataset or not. Defaults to True. + num_repeats (int): The repeat times of every sample. Defaults to 3. + seed (int, optional): Random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Defaults to None. 
+ """ + + def __init__(self, + dataset: Sized, + shuffle: bool = True, + num_repeats: int = 3, + seed: Optional[int] = None): + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.shuffle = shuffle + if not self.shuffle and is_main_process(): + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.warning('The RepeatAugSampler always picks a ' + 'fixed part of data if `shuffle=False`.') + + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.num_repeats = num_repeats + + # The number of repeated samples in the rank + self.num_samples = math.ceil( + len(self.dataset) * num_repeats / world_size) + # The total number of repeated samples in all ranks. + self.total_size = self.num_samples * world_size + # The number of selected samples in the rank + self.num_selected_samples = math.ceil(len(self.dataset) / world_size) + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + # deterministically shuffle based on epoch and seed + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....] + indices = [x for x in indices for _ in range(self.num_repeats)] + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + indices += indices[:padding_size] + assert len(indices) == self.total_size + + # subsample per rank + indices = indices[self.rank:self.total_size:self.world_size] + assert len(indices) == self.num_samples + + # return up to num selected samples + return iter(indices[:self.num_selected_samples]) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_selected_samples + + def set_epoch(self, epoch: int) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch diff --git a/mmpretrain/datasets/samplers/sequential.py b/mmpretrain/datasets/samplers/sequential.py new file mode 100644 index 0000000..e3b940c --- /dev/null +++ b/mmpretrain/datasets/samplers/sequential.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterator + +import torch +from mmengine.dataset import DefaultSampler + +from mmpretrain.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class SequentialSampler(DefaultSampler): + """Sequential sampler which supports different subsample policy. + + Args: + dataset (Sized): The dataset. + round_up (bool): Whether to add extra samples to make the number of + samples evenly divisible by the world size. Defaults to True. + subsample_type (str): The method to subsample data on different rank. + Supported type: + + - ``'default'``: Original torch behavior. Sample the examples one + by one for each GPU in terms. For instance, 8 examples on 2 GPUs, + GPU0: [0,2,4,8], GPU1: [1,3,5,7] + - ``'sequential'``: Subsample all examples to n chunk sequntially. 
+ For instance, 8 examples on 2 GPUs, + GPU0: [0,1,2,3], GPU1: [4,5,6,7] + """ + + def __init__(self, subsample_type: str = 'default', **kwargs) -> None: + super().__init__(shuffle=False, **kwargs) + + if subsample_type not in ['default', 'sequential']: + raise ValueError(f'Unsupported subsample typer "{subsample_type}",' + ' please choose from ["default", "sequential"]') + self.subsample_type = subsample_type + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + + # subsample + if self.subsample_type == 'default': + indices = indices[self.rank:self.total_size:self.world_size] + elif self.subsample_type == 'sequential': + num_samples_per_rank = self.total_size // self.world_size + indices = indices[self.rank * + num_samples_per_rank:(self.rank + 1) * + num_samples_per_rank] + + return iter(indices) diff --git a/mmpretrain/datasets/scienceqa.py b/mmpretrain/datasets/scienceqa.py new file mode 100644 index 0000000..8e44249 --- /dev/null +++ b/mmpretrain/datasets/scienceqa.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import Callable, List, Sequence + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class ScienceQA(BaseDataset): + """ScienceQA dataset. + + This dataset is used to load the multimodal data of ScienceQA dataset. + + Args: + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. + split (str): The split of dataset. Options: ``train``, ``val``, + ``test``, ``trainval``, ``minival``, and ``minitest``. + split_file (str): The split file of dataset, which contains the + ids of data samples in the split. + ann_file (str): Annotation file path. + image_only (bool): Whether only to load data with image. Defaults to + False. + data_prefix (dict): Prefix for data field. Defaults to + ``dict(img_path='')``. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
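+
+ The snippet below is only a usage sketch; ``problems.json`` and
+ ``pid_splits.json`` follow the official ScienceQA release, while the
+ directory layout is an assumption.
+
+ Examples:
+ >>> from mmpretrain.datasets import ScienceQA
+ >>> val_dataset = ScienceQA(
+ ... data_root='data/scienceqa',
+ ... split='val',
+ ... split_file='pid_splits.json',
+ ... ann_file='problems.json',
+ ... image_only=True,
+ ... data_prefix=dict(img_path='val'))
+ >>> # every sample carries the question, choices, hint and, when
+ >>> # available, the path of the associated image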
+ """ + + def __init__(self, + data_root: str, + split: str, + split_file: str, + ann_file: str, + image_only: bool = False, + data_prefix: dict = dict(img_path=''), + pipeline: Sequence[Callable] = (), + **kwargs): + assert split in [ + 'train', 'val', 'test', 'trainval', 'minival', 'minitest' + ], f'Invalid split {split}' + self.split = split + self.split_file = os.path.join(data_root, split_file) + self.image_only = image_only + + super().__init__( + data_root=data_root, + ann_file=ann_file, + data_prefix=data_prefix, + pipeline=pipeline, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + img_prefix = self.data_prefix['img_path'] + annotations = mmengine.load(self.ann_file) + current_data_split = mmengine.load(self.split_file)[self.split] # noqa + + file_backend = get_file_backend(img_prefix) + + data_list = [] + for data_id in current_data_split: + ann = annotations[data_id] + if self.image_only and ann['image'] is None: + continue + data_info = { + 'image_id': + data_id, + 'question': + ann['question'], + 'choices': + ann['choices'], + 'gt_answer': + ann['answer'], + 'hint': + ann['hint'], + 'image_name': + ann['image'], + 'task': + ann['task'], + 'grade': + ann['grade'], + 'subject': + ann['subject'], + 'topic': + ann['topic'], + 'category': + ann['category'], + 'skill': + ann['skill'], + 'lecture': + ann['lecture'], + 'solution': + ann['solution'], + 'split': + ann['split'], + 'img_path': + file_backend.join_path(img_prefix, data_id, ann['image']) + if ann['image'] is not None else None, + 'has_image': + True if ann['image'] is not None else False, + } + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/stanfordcars.py b/mmpretrain/datasets/stanfordcars.py new file mode 100644 index 0000000..3556979 --- /dev/null +++ b/mmpretrain/datasets/stanfordcars.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mat4py +from mmengine import get_file_backend + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import STANFORDCARS_CATEGORIES + + +@DATASETS.register_module() +class StanfordCars(BaseDataset): + """The Stanford Cars Dataset. + + Support the `Stanford Cars Dataset `_ Dataset. + The official website provides two ways to organize the dataset. + Therefore, after downloading and decompression, the dataset directory structure is as follows. + + Stanford Cars dataset directory: :: + + Stanford_Cars + ├── car_ims + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + └── cars_annos.mat + + or :: + + Stanford_Cars + ├── cars_train + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + ├── cars_test + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + └── devkit + ├── cars_meta.mat + ├── cars_train_annos.mat + ├── cars_test_annos.mat + ├── cars_test_annoswithlabels.mat + ├── eval_train.m + └── train_perfect_preds.txt + + Args: + data_root (str): The root directory for Stanford Cars dataset. + split (str, optional): The dataset split, supports "train" + and "test". Default to "train". 
+ + Examples: + >>> from mmpretrain.datasets import StanfordCars + >>> train_dataset = StanfordCars(data_root='data/Stanford_Cars', split='train') + >>> train_dataset + Dataset StanfordCars + Number of samples: 8144 + Number of categories: 196 + Root of dataset: data/Stanford_Cars + >>> test_dataset = StanfordCars(data_root='data/Stanford_Cars', split='test') + >>> test_dataset + Dataset StanfordCars + Number of samples: 8041 + Number of categories: 196 + Root of dataset: data/Stanford_Cars + """ # noqa: E501 + + METAINFO = {'classes': STANFORDCARS_CATEGORIES} + + def __init__(self, data_root: str, split: str = 'train', **kwargs): + + splits = ['train', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + test_mode = split == 'test' + self.backend = get_file_backend(data_root, enable_singleton=True) + + anno_file_path = self.backend.join_path(data_root, 'cars_annos.mat') + if self.backend.exists(anno_file_path): + ann_file = 'cars_annos.mat' + data_prefix = '' + else: + if test_mode: + ann_file = self.backend.join_path( + 'devkit', 'cars_test_annos_withlabels.mat') + data_prefix = 'cars_test' + else: + ann_file = self.backend.join_path('devkit', + 'cars_train_annos.mat') + data_prefix = 'cars_train' + + if not self.backend.exists( + self.backend.join_path(data_root, ann_file)): + doc_url = 'https://mmpretrain.readthedocs.io/en/latest/api/datasets.html#stanfordcars' # noqa: E501 + raise RuntimeError( + f'The dataset is incorrectly organized, please \ + refer to {doc_url} and reorganize your folders.') + + super(StanfordCars, self).__init__( + ann_file=ann_file, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self): + data = mat4py.loadmat(self.ann_file)['annotations'] + + data_list = [] + if 'test' in data.keys(): + # first way + img_paths, labels, test = data['relative_im_path'], data[ + 'class'], data['test'] + num = len(img_paths) + assert num == len(labels) == len(test), 'get error ann file' + for i in range(num): + if not self.test_mode and test[i] == 1: + continue + if self.test_mode and test[i] == 0: + continue + img_path = self.backend.join_path(self.img_prefix, + img_paths[i]) + gt_label = labels[i] - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + else: + # second way + img_names, labels = data['fname'], data['class'] + num = len(img_names) + assert num == len(labels), 'get error ann file' + for i in range(num): + img_path = self.backend.join_path(self.img_prefix, + img_names[i]) + gt_label = labels[i] - 1 + info = dict(img_path=img_path, gt_label=gt_label) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Root of dataset: \t{self.data_root}', + ] + return body diff --git a/mmpretrain/datasets/sun397.py b/mmpretrain/datasets/sun397.py new file mode 100644 index 0000000..1039a06 --- /dev/null +++ b/mmpretrain/datasets/sun397.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine import get_file_backend, list_from_file + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset +from .categories import SUN397_CATEGORIES + + +@DATASETS.register_module() +class SUN397(BaseDataset): + """The SUN397 Dataset. + + Support the `SUN397 Dataset `_ Dataset. + After downloading and decompression, the dataset directory structure is as follows. 
+
+    SUN397 dataset directory: ::
+
+        SUN397
+        ├── SUN397
+        │   ├── a
+        │   │   ├── abbey
+        │   │   │   ├── sun_aaalbzqrimafwbiv.jpg
+        │   │   │   └── ...
+        │   │   ├── airplane_cabin
+        │   │   │   ├── sun_aadqdkqaslqqoblu.jpg
+        │   │   │   └── ...
+        │   │   └── ...
+        │   ├── b
+        │   │   └── ...
+        │   ├── c
+        │   │   └── ...
+        │   └── ...
+        └── Partitions
+            ├── ClassName.txt
+            ├── Training_01.txt
+            ├── Testing_01.txt
+            └── ...
+
+    Args:
+        data_root (str): The root directory for the SUN397 dataset.
+        split (str, optional): The dataset split, supports "train" and "test".
+            Defaults to "train".
+
+    Examples:
+        >>> from mmpretrain.datasets import SUN397
+        >>> train_dataset = SUN397(data_root='data/SUN397', split='train')
+        >>> train_dataset
+        Dataset SUN397
+            Number of samples:  19850
+            Number of categories:       397
+            Root of dataset:    data/SUN397
+        >>> test_dataset = SUN397(data_root='data/SUN397', split='test')
+        >>> test_dataset
+        Dataset SUN397
+            Number of samples:  19850
+            Number of categories:       397
+            Root of dataset:    data/SUN397
+
+    **Note that some images are not JPEG files even though their names end
+    with ".jpg". The backend of SUN397 should be "pillow", as below, to read
+    these images properly.**
+
+    .. code-block:: python
+
+        pipeline = [
+            dict(type='LoadImageFromFile', imdecode_backend='pillow'),
+            dict(type='RandomResizedCrop', scale=224),
+            dict(type='PackInputs')
+        ]
+    """  # noqa: E501
+
+    METAINFO = {'classes': SUN397_CATEGORIES}
+
+    def __init__(self, data_root: str, split: str = 'train', **kwargs):
+
+        splits = ['train', 'test']
+        assert split in splits, \
+            f"The split must be one of {splits}, but got '{split}'"
+        self.split = split
+
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        if split == 'train':
+            ann_file = self.backend.join_path('Partitions', 'Training_01.txt')
+        else:
+            ann_file = self.backend.join_path('Partitions', 'Testing_01.txt')
+
+        data_prefix = 'SUN397'
+        test_mode = split == 'test'
+
+        super(SUN397, self).__init__(
+            ann_file=ann_file,
+            data_root=data_root,
+            test_mode=test_mode,
+            data_prefix=data_prefix,
+            **kwargs)
+
+    def load_data_list(self):
+        pairs = list_from_file(self.ann_file)
+        data_list = []
+        for pair in pairs:
+            img_path = self.backend.join_path(self.img_prefix, pair[1:])
+            items = pair.split('/')
+            class_name = '_'.join(items[2:-1])
+            gt_label = self.METAINFO['classes'].index(class_name)
+            info = dict(img_path=img_path, gt_label=gt_label)
+            data_list.append(info)
+
+        return data_list
+
+    def __getitem__(self, idx: int) -> dict:
+        try:
+            return super().__getitem__(idx)
+        except AttributeError:
+            raise RuntimeError(
+                'Some images in the SUN397 dataset are not a jpg file '
+                'although the name ends with ".jpg". The backend of SUN397 '
+                'should be "pillow" to read these images properly.')
+
+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [
+            f'Root of dataset: \t{self.data_root}',
+        ]
+        return body
diff --git a/mmpretrain/datasets/textvqa.py b/mmpretrain/datasets/textvqa.py
new file mode 100644
index 0000000..48a82b4
--- /dev/null
+++ b/mmpretrain/datasets/textvqa.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import Counter
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class TextVQA(BaseDataset):
+    """TextVQA dataset.
+ + val image: + https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip + test image: + https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip + val json: + https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json + test json: + https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json + + folder structure: + data/textvqa + ├── annotations + │ ├── TextVQA_0.5.1_test.json + │ └── TextVQA_0.5.1_val.json + └── images + ├── test_images + └── train_images + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + question_file (str): Question file path. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file)['data'] + + data_list = [] + + for ann in annotations: + + # ann example + # { + # 'question': 'what is the brand of...is camera?', + # 'image_id': '003a8ae2ef43b901', + # 'image_classes': [ + # 'Cassette deck', 'Printer', ... + # ], + # 'flickr_original_url': 'https://farm2.static...04a6_o.jpg', + # 'flickr_300k_url': 'https://farm2.static...04a6_o.jpg', + # 'image_width': 1024, + # 'image_height': 664, + # 'answers': [ + # 'nous les gosses', + # 'dakota', + # 'clos culombu', + # 'dakota digital' ... + # ], + # 'question_tokens': + # ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'], + # 'question_id': 34602, + # 'set_name': 'val' + # } + + data_info = dict(question=ann['question']) + data_info['question_id'] = ann['question_id'] + data_info['image_id'] = ann['image_id'] + + img_path = mmengine.join_path(self.data_prefix['img_path'], + ann['image_id'] + '.jpg') + data_info['img_path'] = img_path + + data_info['question_id'] = ann['question_id'] + + if 'answers' in ann: + answers = [item for item in ann.pop('answers')] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + data_info['gt_answer'] = list(count.keys()) + data_info['gt_answer_weight'] = answer_weight + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py new file mode 100644 index 0000000..617503f --- /dev/null +++ b/mmpretrain/datasets/transforms/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmcv.transforms import (CenterCrop, LoadImageFromFile, Normalize, + RandomFlip, RandomGrayscale, RandomResize, Resize) + +from mmpretrain.registry import TRANSFORMS +from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform, + Brightness, ColorTransform, Contrast, Cutout, + Equalize, GaussianBlur, Invert, Posterize, + RandAugment, Rotate, Sharpness, Shear, Solarize, + SolarizeAdd, Translate) +from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs, + PILToNumpy, Transpose) +from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption, + ColorJitter, EfficientNetCenterCrop, + EfficientNetRandomCrop, Lighting, + MAERandomResizedCrop, RandomCrop, RandomErasing, + RandomResizedCrop, + RandomResizedCropAndInterpolationWithTwoPic, + RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator) +from .utils import get_transform_idx, remove_transform +from .wrappers import ApplyToList, MultiView + +for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip, + RandomGrayscale, RandomResize, Resize): + TRANSFORMS.register_module(module=t) + +__all__ = [ + 'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop', + 'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert', + 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', + 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', + 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', + 'PackInputs', 'Albumentations', 'EfficientNetRandomCrop', + 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform', + 'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator', + 'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize', + 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView', + 'ApplyToList', 'CleanCaption', 'RandomTranslatePad', + 'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx', + 'remove_transform', 'MAERandomResizedCrop' +] diff --git a/mmpretrain/datasets/transforms/auto_augment.py b/mmpretrain/datasets/transforms/auto_augment.py new file mode 100644 index 0000000..4705d5e --- /dev/null +++ b/mmpretrain/datasets/transforms/auto_augment.py @@ -0,0 +1,1244 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +from copy import deepcopy +from math import ceil +from numbers import Number +from typing import List, Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform, Compose, RandomChoice +from mmcv.transforms.utils import cache_randomness +from mmengine.utils import is_list_of, is_seq_of +from PIL import Image, ImageFilter + +from mmpretrain.registry import TRANSFORMS + + +def merge_hparams(policy: dict, hparams: dict) -> dict: + """Merge hyperparameters into policy config. + + Only merge partial hyperparameters required of the policy. + + Args: + policy (dict): Original policy config dict. + hparams (dict): Hyperparameters need to be merged. + + Returns: + dict: Policy config dict after adding ``hparams``. + """ + policy = deepcopy(policy) + op = TRANSFORMS.get(policy['type']) + assert op is not None, f'Invalid policy type "{policy["type"]}".' + + op_args = inspect.getfullargspec(op.__init__).args + for key, value in hparams.items(): + if key in op_args and key not in policy: + policy[key] = value + return policy + + +@TRANSFORMS.register_module() +class AutoAugment(RandomChoice): + """Auto augmentation. + + This data augmentation is proposed in `AutoAugment: Learning Augmentation + Policies from Data `_. 
+
+    Args:
+        policies (str | list[list[dict]]): The policies of auto augmentation.
+            If string, use preset policies collection like "imagenet". If
+            list, each item is a sub-policy composed of several augmentation
+            policy dicts. When AutoAugment is called, a random sub-policy in
+            ``policies`` will be selected to augment images.
+        hparams (dict): Configs of hyperparameters. Hyperparameters will be
+            used in policies that require these arguments if these arguments
+            are not set in policy dicts. Defaults to ``dict(pad_val=128)``.
+
+    .. admonition:: Available preset policies
+
+        - ``"imagenet"``: Policy for ImageNet, comes from
+          `DeepVoltaire/AutoAugment`_
+
+    .. _DeepVoltaire/AutoAugment: https://github.com/DeepVoltaire/AutoAugment
+    """
+
+    def __init__(self,
+                 policies: Union[str, List[List[dict]]],
+                 hparams: dict = dict(pad_val=128)):
+        if isinstance(policies, str):
+            assert policies in AUTOAUG_POLICIES, 'Invalid policies, ' \
+                f'please choose from {list(AUTOAUG_POLICIES.keys())}.'
+            policies = AUTOAUG_POLICIES[policies]
+        self.hparams = hparams
+        self.policies = [[merge_hparams(t, hparams) for t in sub]
+                         for sub in policies]
+        transforms = [[TRANSFORMS.build(t) for t in sub] for sub in policies]
+
+        super().__init__(transforms=transforms)
+
+    def __repr__(self) -> str:
+        policies_str = ''
+        for sub in self.policies:
+            policies_str += '\n    ' + ', \t'.join([t['type'] for t in sub])
+
+        repr_str = self.__class__.__name__
+        repr_str += f'(policies:{policies_str}\n)'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandAugment(BaseTransform):
+    r"""Random augmentation.
+
+    This data augmentation is proposed in `RandAugment: Practical automated
+    data augmentation with a reduced search space
+    <https://arxiv.org/abs/1909.13719>`_.
+
+    Args:
+        policies (str | list[dict]): The policies of random augmentation.
+            If string, use preset policies collection like "timm_increasing".
+            If list, each item is one specific augmentation policy dict.
+            The policy dict should have these keys:
+
+            - ``type`` (str): The type of augmentation.
+            - ``magnitude_range`` (Sequence[number], optional): For those
+              augmentations that have a magnitude, you need to specify the
+              magnitude level mapping range. For example, assume
+              ``total_level`` is 10, ``magnitude_level=3`` specifies a
+              magnitude of 3 if ``magnitude_range=(0, 10)``, while it
+              specifies a magnitude of 7 if ``magnitude_range=(10, 0)``.
+            - other keyword arguments of the augmentation.
+
+        num_policies (int): Number of policies to select from policies each
+            time.
+        magnitude_level (int | float): Magnitude level for all the selected
+            augmentations.
+        magnitude_std (Number | str): Deviation of magnitude noise applied.
+
+            - If positive number, the magnitude obeys normal distribution
+              :math:`\mathcal{N}(magnitude\_level, magnitude\_std)`.
+            - If 0 or negative number, magnitude remains unchanged.
+            - If str "inf", the magnitude obeys uniform distribution
+              :math:`Uniform(min, magnitude)`.
+        total_level (int | float): Total level for the magnitude. Defaults to
+            10.
+        hparams (dict): Configs of hyperparameters. Hyperparameters will be
+            used in policies that require these arguments if these arguments
+            are not set in policy dicts. Defaults to ``dict(pad_val=128)``.
+
+    .. admonition:: Available preset policies
+
+        - ``"timm_increasing"``: The ``_RAND_INCREASING_TRANSFORMS`` policy
+          from `timm`_
+
+    ..
_timm: https://github.com/rwightman/pytorch-image-models + + Examples: + + To use "timm-increasing" policies collection, select two policies every + time, and magnitude_level of every policy is 6 (total is 10 by default) + + >>> import numpy as np + >>> from mmpretrain.datasets import RandAugment + >>> transform = RandAugment( + ... policies='timm_increasing', + ... num_policies=2, + ... magnitude_level=6, + ... ) + >>> data = {'img': np.random.randint(0, 256, (224, 224, 3))} + >>> results = transform(data) + >>> print(results['img'].shape) + (224, 224, 3) + + If you want the ``magnitude_level`` randomly changes every time, you + can use ``magnitude_std`` to specify the random distribution. For + example, a normal distribution :math:`\mathcal{N}(6, 0.5)`. + + >>> transform = RandAugment( + ... policies='timm_increasing', + ... num_policies=2, + ... magnitude_level=6, + ... magnitude_std=0.5, + ... ) + + You can also use your own policies: + + >>> policies = [ + ... dict(type='AutoContrast'), + ... dict(type='Rotate', magnitude_range=(0, 30)), + ... dict(type='ColorTransform', magnitude_range=(0, 0.9)), + ... ] + >>> transform = RandAugment( + ... policies=policies, + ... num_policies=2, + ... magnitude_level=6 + ... ) + + Note: + ``magnitude_std`` will introduce some randomness to policy, modified by + https://github.com/rwightman/pytorch-image-models. + + When magnitude_std=0, we calculate the magnitude as follows: + + .. math:: + \text{magnitude} = \frac{\text{magnitude_level}} + {\text{totallevel}} \times (\text{val2} - \text{val1}) + + \text{val1} + """ + + def __init__(self, + policies: Union[str, List[dict]], + num_policies: int, + magnitude_level: int, + magnitude_std: Union[Number, str] = 0., + total_level: int = 10, + hparams: dict = dict(pad_val=128)): + if isinstance(policies, str): + assert policies in RANDAUG_POLICIES, 'Invalid policies, ' \ + f'please choose from {list(RANDAUG_POLICIES.keys())}.' + policies = RANDAUG_POLICIES[policies] + + assert is_list_of(policies, dict), 'policies must be a list of dict.' + + assert isinstance(magnitude_std, (Number, str)), \ + '`magnitude_std` must be of number or str type, ' \ + f'got {type(magnitude_std)} instead.' + if isinstance(magnitude_std, str): + assert magnitude_std == 'inf', \ + '`magnitude_std` must be of number or "inf", ' \ + f'got "{magnitude_std}" instead.' + + assert num_policies > 0, 'num_policies must be greater than 0.' + assert magnitude_level >= 0, 'magnitude_level must be no less than 0.' + assert total_level > 0, 'total_level must be greater than 0.' + + self.num_policies = num_policies + self.magnitude_level = magnitude_level + self.magnitude_std = magnitude_std + self.total_level = total_level + self.hparams = hparams + self.policies = [] + self.transforms = [] + + randaug_cfg = dict( + magnitude_level=magnitude_level, + total_level=total_level, + magnitude_std=magnitude_std) + + for policy in policies: + self._check_policy(policy) + policy = merge_hparams(policy, hparams) + policy.pop('magnitude_key', None) # For backward compatibility + if 'magnitude_range' in policy: + policy.update(randaug_cfg) + self.policies.append(policy) + self.transforms.append(TRANSFORMS.build(policy)) + + def __iter__(self): + """Iterate all transforms.""" + return iter(self.transforms) + + def _check_policy(self, policy): + """Check whether the sub-policy dict is available.""" + assert isinstance(policy, dict) and 'type' in policy, \ + 'Each policy must be a dict with key "type".' 
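+        # Note: only policies that carry a `magnitude_range` are checked
+        # further below; parameter-free policies such as AutoContrast are
+        # valid as bare ``dict(type=...)`` entries.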
+ type_name = policy['type'] + + if 'magnitude_range' in policy: + magnitude_range = policy['magnitude_range'] + assert is_seq_of(magnitude_range, Number), \ + f'`magnitude_range` of RandAugment policy {type_name} ' \ + 'should be a sequence with two numbers.' + + @cache_randomness + def random_policy_indices(self) -> np.ndarray: + """Return the random chosen transform indices.""" + indices = np.arange(len(self.policies)) + return np.random.choice(indices, size=self.num_policies).tolist() + + def transform(self, results: dict) -> Optional[dict]: + """Randomly choose a sub-policy to apply.""" + + chosen_policies = [ + self.transforms[i] for i in self.random_policy_indices() + ] + + sub_pipeline = Compose(chosen_policies) + return sub_pipeline(results) + + def __repr__(self) -> str: + policies_str = '' + for policy in self.policies: + policies_str += '\n ' + f'{policy["type"]}' + if 'magnitude_range' in policy: + val1, val2 = policy['magnitude_range'] + policies_str += f' ({val1}, {val2})' + + repr_str = self.__class__.__name__ + repr_str += f'(num_policies={self.num_policies}, ' + repr_str += f'magnitude_level={self.magnitude_level}, ' + repr_str += f'total_level={self.total_level}, ' + repr_str += f'policies:{policies_str}\n)' + return repr_str + + +class BaseAugTransform(BaseTransform): + r"""The base class of augmentation transform for RandAugment. + + This class provides several common attributions and methods to support the + magnitude level mapping and magnitude level randomness in + :class:`RandAugment`. + + Args: + magnitude_level (int | float): Magnitude level. + magnitude_range (Sequence[number], optional): For augmentation have + magnitude argument, maybe "magnitude", "angle" or other, you can + specify the magnitude level mapping range to generate the magnitude + argument. For example, assume ``total_level`` is 10, + ``magnitude_level=3`` specify magnitude is 3 if + ``magnitude_range=(0, 10)`` while specify magnitude is 7 if + ``magnitude_range=(10, 0)``. Defaults to None. + magnitude_std (Number | str): Deviation of magnitude noise applied. + + - If positive number, the magnitude obeys normal distribution + :math:`\mathcal{N}(magnitude, magnitude_std)`. + - If 0 or negative number, magnitude remains unchanged. + - If str "inf", the magnitude obeys uniform distribution + :math:`Uniform(min, magnitude)`. + + Defaults to 0. + total_level (int | float): Total level for the magnitude. Defaults to + 10. + prob (float): The probability for performing transformation therefore + should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0. + """ + + def __init__(self, + magnitude_level: int = 10, + magnitude_range: Tuple[float, float] = None, + magnitude_std: Union[str, float] = 0., + total_level: int = 10, + prob: float = 0.5, + random_negative_prob: float = 0.5): + self.magnitude_level = magnitude_level + self.magnitude_range = magnitude_range + self.magnitude_std = magnitude_std + self.total_level = total_level + self.prob = prob + self.random_negative_prob = random_negative_prob + + @cache_randomness + def random_disable(self): + """Randomly disable the transform.""" + return np.random.rand() > self.prob + + @cache_randomness + def random_magnitude(self): + """Randomly generate magnitude.""" + magnitude = self.magnitude_level + # if magnitude_std is positive number or 'inf', move + # magnitude_value randomly. 
+ if self.magnitude_std == 'inf': + magnitude = np.random.uniform(0, magnitude) + elif self.magnitude_std > 0: + magnitude = np.random.normal(magnitude, self.magnitude_std) + magnitude = np.clip(magnitude, 0, self.total_level) + + val1, val2 = self.magnitude_range + magnitude = (magnitude / self.total_level) * (val2 - val1) + val1 + return magnitude + + @cache_randomness + def random_negative(self, value): + """Randomly negative the value.""" + if np.random.rand() < self.random_negative_prob: + return -value + else: + return value + + def extra_repr(self): + """Extra repr string when auto-generating magnitude is enabled.""" + if self.magnitude_range is not None: + repr_str = f', magnitude_level={self.magnitude_level}, ' + repr_str += f'magnitude_range={self.magnitude_range}, ' + repr_str += f'magnitude_std={self.magnitude_std}, ' + repr_str += f'total_level={self.total_level}, ' + return repr_str + else: + return '' + + +@TRANSFORMS.register_module() +class Shear(BaseAugTransform): + """Shear images. + + Args: + magnitude (int | float | None): The magnitude used for shear. If None, + generate from ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing shear therefore should be + in range [0, 1]. Defaults to 0.5. + direction (str): The shearing direction. Options are 'horizontal' and + 'vertical'. Defaults to 'horizontal'. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'bicubic'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + direction: str = 'horizontal', + random_negative_prob: float = 0.5, + interpolation: str = 'bicubic', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + assert direction in ('horizontal', 'vertical'), 'direction must be ' \ + f'either "horizontal" or "vertical", got "{direction}" instead.' 
+ self.direction = direction + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_sheared = mmcv.imshear( + img, + magnitude, + direction=self.direction, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_sheared.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'direction={self.direction}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Translate(BaseAugTransform): + """Translate images. + + Args: + magnitude (int | float | None): The magnitude used for translate. Note + that the offset is calculated by magnitude * size in the + corresponding direction. With a magnitude of 1, the whole image + will be moved out of the range. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing translate therefore should + be in range [0, 1]. Defaults to 0.5. + direction (str): The translating direction. Options are 'horizontal' + and 'vertical'. Defaults to 'horizontal'. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + direction: str = 'horizontal', + random_negative_prob: float = 0.5, + interpolation: str = 'nearest', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + assert direction in ('horizontal', 'vertical'), 'direction must be ' \ + f'either "horizontal" or "vertical", got "{direction}" instead.' 
+ self.direction = direction + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + height, width = img.shape[:2] + if self.direction == 'horizontal': + offset = magnitude * width + else: + offset = magnitude * height + img_translated = mmcv.imtranslate( + img, + offset, + direction=self.direction, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_translated.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'direction={self.direction}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Rotate(BaseAugTransform): + """Rotate images. + + Args: + angle (float, optional): The angle used for rotate. Positive values + stand for clockwise rotation. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If None, the center of the image will be used. + Defaults to None. + scale (float): Isotropic scale factor. Defaults to 1.0. + pad_val (int, Sequence[int]): Pixel pad_val value for constant fill. + If a sequence of length 3, it is used to pad_val R, G, B channels + respectively. Defaults to 128. + prob (float): The probability for performing rotate therefore should be + in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the angle + negative, which should be in range [0,1]. Defaults to 0.5. + interpolation (str): Interpolation method. Options are 'nearest', + 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + angle: Optional[float] = None, + center: Optional[Tuple[float]] = None, + scale: float = 1.0, + pad_val: Union[int, Sequence[int]] = 128, + prob: float = 0.5, + random_negative_prob: float = 0.5, + interpolation: str = 'nearest', + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (angle is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `angle` and `magnitude_range`.' 
+ + self.angle = angle + self.center = center + self.scale = scale + if isinstance(pad_val, Sequence): + self.pad_val = tuple(pad_val) + else: + self.pad_val = pad_val + + self.interpolation = interpolation + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.angle is not None: + angle = self.random_negative(self.angle) + else: + angle = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_rotated = mmcv.imrotate( + img, + angle, + center=self.center, + scale=self.scale, + border_value=self.pad_val, + interpolation=self.interpolation) + results['img'] = img_rotated.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(angle={self.angle}, ' + repr_str += f'center={self.center}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}, ' + repr_str += f'interpolation={self.interpolation}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class AutoContrast(BaseAugTransform): + """Auto adjust image contrast. + + Args: + prob (float): The probability for performing auto contrast + therefore should be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_contrasted = mmcv.auto_contrast(img) + results['img'] = img_contrasted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Invert(BaseAugTransform): + """Invert images. + + Args: + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_inverted = mmcv.iminvert(img) + results['img'] = img_inverted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Equalize(BaseAugTransform): + """Equalize the image histogram. + + Args: + prob (float): The probability for performing equalize therefore should + be in range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, prob: float = 0.5, **kwargs): + super().__init__(prob=prob, **kwargs) + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + img = results['img'] + img_equalized = mmcv.imequalize(img) + results['img'] = img_equalized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + +@TRANSFORMS.register_module() +class Solarize(BaseAugTransform): + """Solarize images (invert all pixel values above a threshold). 
+ + Args: + thr (int | float | None): The threshold above which the pixels value + will be inverted. If None, generate from ``magnitude_range``, + see :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for solarizing therefore should be in + range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + thr: Union[int, float, None] = None, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (thr is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `thr` and `magnitude_range`.' + + self.thr = thr + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.thr is not None: + thr = self.thr + else: + thr = self.random_magnitude() + + img = results['img'] + img_solarized = mmcv.solarize(img, thr=thr) + results['img'] = img_solarized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(thr={self.thr}, ' + repr_str += f'prob={self.prob}{self.extra_repr()}))' + return repr_str + + +@TRANSFORMS.register_module() +class SolarizeAdd(BaseAugTransform): + """SolarizeAdd images (add a certain value to pixels below a threshold). + + Args: + magnitude (int | float | None): The value to be added to pixels below + the thr. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + thr (int | float): The threshold below which the pixels value will be + adjusted. + prob (float): The probability for solarizing therefore should be in + range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + thr: Union[int, float] = 128, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + assert isinstance(thr, (int, float)), 'The thr type must '\ + f'be int or float, but got {type(thr)} instead.' + self.thr = thr + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.magnitude + else: + magnitude = self.random_magnitude() + + img = results['img'] + img_solarized = np.where(img < self.thr, + np.minimum(img + magnitude, 255), img) + results['img'] = img_solarized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'thr={self.thr}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Posterize(BaseAugTransform): + """Posterize images (reduce the number of bits for each color channel). + + Args: + bits (int, optional): Number of bits for each pixel in the output img, + which should be less or equal to 8. If None, generate from + ``magnitude_range``, see :class:`BaseAugTransform`. + Defaults to None. + prob (float): The probability for posterizing therefore should be in + range [0, 1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. 
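+
+    Example:
+        A minimal sketch on a synthetic random image (any ``uint8`` HWC array
+        stored under the ``'img'`` key works the same way):
+
+        >>> import numpy as np
+        >>> from mmpretrain.datasets.transforms import Posterize
+        >>> transform = Posterize(bits=4, prob=1.)
+        >>> results = {'img': np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)}
+        >>> transform(results)['img'].shape
+        (224, 224, 3)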
+ """ + + def __init__(self, + bits: Optional[int] = None, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (bits is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `bits` and `magnitude_range`.' + + if bits is not None: + assert bits <= 8, \ + f'The bits must be less than 8, got {bits} instead.' + self.bits = bits + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.bits is not None: + bits = self.bits + else: + bits = self.random_magnitude() + + # To align timm version, we need to round up to integer here. + bits = ceil(bits) + + img = results['img'] + img_posterized = mmcv.posterize(img, bits=bits) + results['img'] = img_posterized.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(bits={self.bits}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Contrast(BaseAugTransform): + """Adjust images contrast. + + Args: + magnitude (int | float | None): The magnitude used for adjusting + contrast. A positive magnitude would enhance the contrast and + a negative magnitude would make the image grayer. A magnitude=0 + gives the origin img. If None, generate from ``magnitude_range``, + see :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing contrast adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_contrasted = mmcv.adjust_contrast(img, factor=1 + magnitude) + results['img'] = img_contrasted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class ColorTransform(BaseAugTransform): + """Adjust images color balance. + + Args: + magnitude (int | float | None): The magnitude used for color transform. + A positive magnitude would enhance the color and a negative + magnitude would make the image grayer. A magnitude=0 gives the + origin img. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing ColorTransform therefore + should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. 
+ """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_color_adjusted = mmcv.adjust_color(img, alpha=1 + magnitude) + results['img'] = img_color_adjusted.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Brightness(BaseAugTransform): + """Adjust images brightness. + + Args: + magnitude (int | float | None): The magnitude used for adjusting + brightness. A positive magnitude would enhance the brightness and a + negative magnitude would make the image darker. A magnitude=0 gives + the origin img. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing brightness adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + **kwargs: Other keyword arguments of :class:`BaseAugTransform`. + """ + + def __init__(self, + magnitude: Union[int, float, None] = None, + prob: float = 0.5, + random_negative_prob: float = 0.5, + **kwargs): + super().__init__( + prob=prob, random_negative_prob=random_negative_prob, **kwargs) + assert (magnitude is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `magnitude` and `magnitude_range`.' + + self.magnitude = magnitude + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.magnitude is not None: + magnitude = self.random_negative(self.magnitude) + else: + magnitude = self.random_negative(self.random_magnitude()) + + img = results['img'] + img_brightened = mmcv.adjust_brightness(img, factor=1 + magnitude) + results['img'] = img_brightened.astype(img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob}' + repr_str += f'{self.extra_repr()})' + return repr_str + + +@TRANSFORMS.register_module() +class Sharpness(BaseAugTransform): + """Adjust images sharpness. + + Args: + magnitude (int | float | None): The magnitude used for adjusting + sharpness. A positive magnitude would enhance the sharpness and a + negative magnitude would make the image bulr. A magnitude=0 gives + the origin img. If None, generate from ``magnitude_range``, see + :class:`BaseAugTransform`. Defaults to None. + prob (float): The probability for performing sharpness adjusting + therefore should be in range [0, 1]. Defaults to 0.5. 
+        random_negative_prob (float): The probability that turns the magnitude
+            negative, which should be in range [0,1]. Defaults to 0.5.
+        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+    """
+
+    def __init__(self,
+                 magnitude: Union[int, float, None] = None,
+                 prob: float = 0.5,
+                 random_negative_prob: float = 0.5,
+                 **kwargs):
+        super().__init__(
+            prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+        assert (magnitude is None) ^ (self.magnitude_range is None), \
+            'Please specify only one of `magnitude` and `magnitude_range`.'
+
+        self.magnitude = magnitude
+
+    def transform(self, results):
+        """Apply transform to results."""
+        if self.random_disable():
+            return results
+
+        if self.magnitude is not None:
+            magnitude = self.random_negative(self.magnitude)
+        else:
+            magnitude = self.random_negative(self.random_magnitude())
+
+        img = results['img']
+        img_sharpened = mmcv.adjust_sharpness(img, factor=1 + magnitude)
+        results['img'] = img_sharpened.astype(img.dtype)
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(magnitude={self.magnitude}, '
+        repr_str += f'prob={self.prob}, '
+        repr_str += f'random_negative_prob={self.random_negative_prob}'
+        repr_str += f'{self.extra_repr()})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class Cutout(BaseAugTransform):
+    """Cutout images.
+
+    Args:
+        shape (int | tuple(int) | None): Expected cutout shape (h, w).
+            If given as a single value, the value will be used for both h and
+            w. If None, generate from ``magnitude_range``, see
+            :class:`BaseAugTransform`. Defaults to None.
+        pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
+            If it is a sequence, it must have the same length with the image
+            channels. Defaults to 128.
+        prob (float): The probability of performing cutout, which should
+            be in range [0, 1]. Defaults to 0.5.
+        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+    """
+
+    def __init__(self,
+                 shape: Union[int, Tuple[int], None] = None,
+                 pad_val: Union[int, Sequence[int]] = 128,
+                 prob: float = 0.5,
+                 **kwargs):
+        super().__init__(prob=prob, random_negative_prob=0., **kwargs)
+        assert (shape is None) ^ (self.magnitude_range is None), \
+            'Please specify only one of `shape` and `magnitude_range`.'
+
+        self.shape = shape
+        if isinstance(pad_val, Sequence):
+            self.pad_val = tuple(pad_val)
+        else:
+            self.pad_val = pad_val
+
+    def transform(self, results):
+        """Apply transform to results."""
+        if self.random_disable():
+            return results
+
+        if self.shape is not None:
+            shape = self.shape
+        else:
+            shape = int(self.random_magnitude())
+
+        img = results['img']
+        img_cutout = mmcv.cutout(img, shape, pad_val=self.pad_val)
+        results['img'] = img_cutout.astype(img.dtype)
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(shape={self.shape}, '
+        repr_str += f'pad_val={self.pad_val}, '
+        repr_str += f'prob={self.prob}{self.extra_repr()})'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class GaussianBlur(BaseAugTransform):
+    """Gaussian blur images.
+
+    Args:
+        radius (int, float, optional): The blur radius. If None, generate from
+            ``magnitude_range``, see :class:`BaseAugTransform`.
+            Defaults to None.
+        prob (float): The probability of performing the blur, which should be
+            in range [0, 1]. Defaults to 0.5.
+        **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
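+
+    Example:
+        A minimal sketch on a synthetic random image:
+
+        >>> import numpy as np
+        >>> from mmpretrain.datasets.transforms import GaussianBlur
+        >>> transform = GaussianBlur(radius=1.5, prob=1.)
+        >>> results = {'img': np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)}
+        >>> transform(results)['img'].shape
+        (224, 224, 3)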
+ """ + + def __init__(self, + radius: Union[int, float, None] = None, + prob: float = 0.5, + **kwargs): + super().__init__(prob=prob, random_negative_prob=0., **kwargs) + assert (radius is None) ^ (self.magnitude_range is None), \ + 'Please specify only one of `radius` and `magnitude_range`.' + + self.radius = radius + + def transform(self, results): + """Apply transform to results.""" + if self.random_disable(): + return results + + if self.radius is not None: + radius = self.radius + else: + radius = self.random_magnitude() + + img = results['img'] + pil_img = Image.fromarray(img) + pil_img = pil_img.filter(ImageFilter.GaussianBlur(radius=radius)) + results['img'] = np.array(pil_img, dtype=img.dtype) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(radius={self.radius}, ' + repr_str += f'prob={self.prob}{self.extra_repr()})' + return repr_str + + +# yapf: disable +# flake8: noqa +AUTOAUG_POLICIES = { + # Policy for ImageNet, refers to + # https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py + 'imagenet': [ + [dict(type='Posterize', bits=4, prob=0.4), dict(type='Rotate', angle=30., prob=0.6)], + [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)], + [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)], + [dict(type='Posterize', bits=5, prob=0.6), dict(type='Posterize', bits=5, prob=0.6)], + [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)], + [dict(type='Equalize', prob=0.4), dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)], + [dict(type='Solarize', thr=256 / 9 * 6, prob=0.6), dict(type='Equalize', prob=0.6)], + [dict(type='Posterize', bits=6, prob=0.8), dict(type='Equalize', prob=1.)], + [dict(type='Rotate', angle=10., prob=0.2), dict(type='Solarize', thr=256 / 9, prob=0.6)], + [dict(type='Equalize', prob=0.6), dict(type='Posterize', bits=5, prob=0.4)], + [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0., prob=0.4)], + [dict(type='Rotate', angle=30., prob=0.4), dict(type='Equalize', prob=0.6)], + [dict(type='Equalize', prob=0.0), dict(type='Equalize', prob=0.8)], + [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)], + [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0.2, prob=1.)], + [dict(type='ColorTransform', magnitude=0.8, prob=0.8), dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)], + [dict(type='Sharpness', magnitude=0.7, prob=0.4), dict(type='Invert', prob=0.6)], + [dict(type='Shear', magnitude=0.3 / 9 * 5, prob=0.6, direction='horizontal'), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0., prob=0.4), dict(type='Equalize', prob=0.6)], + [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)], + [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)], + [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)], + [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)], + [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)], + ], +} + +RANDAUG_POLICIES = { + # Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models + 'timm_increasing': [ + dict(type='AutoContrast'), + dict(type='Equalize'), + dict(type='Invert'), + dict(type='Rotate', magnitude_range=(0, 30)), + 
dict(type='Posterize', magnitude_range=(4, 0)), + dict(type='Solarize', magnitude_range=(256, 0)), + dict(type='SolarizeAdd', magnitude_range=(0, 110)), + dict(type='ColorTransform', magnitude_range=(0, 0.9)), + dict(type='Contrast', magnitude_range=(0, 0.9)), + dict(type='Brightness', magnitude_range=(0, 0.9)), + dict(type='Sharpness', magnitude_range=(0, 0.9)), + dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'), + dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'), + dict(type='Translate', magnitude_range=(0, 0.45), direction='horizontal'), + dict(type='Translate', magnitude_range=(0, 0.45), direction='vertical'), + ], + 'simple_increasing': [ + dict(type='AutoContrast'), + dict(type='Equalize'), + dict(type='Rotate', magnitude_range=(0, 30)), + dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'), + dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'), + ], +} diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py new file mode 100644 index 0000000..e4d3316 --- /dev/null +++ b/mmpretrain/datasets/transforms/formatting.py @@ -0,0 +1,353 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from collections.abc import Sequence + +import cv2 +import numpy as np +import torch +import torchvision.transforms.functional as F +from mmcv.transforms import BaseTransform +from mmengine.utils import is_str +from PIL import Image + +from mmpretrain.registry import TRANSFORMS +from mmpretrain.structures import DataSample, MultiTaskDataSample + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + """ + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError( + f'Type {type(data)} cannot be converted to tensor.' + 'Supported types are: `numpy.ndarray`, `torch.Tensor`, ' + '`Sequence`, `int` and `float`') + + +@TRANSFORMS.register_module() +class PackInputs(BaseTransform): + """Pack the inputs data. + + **Required Keys:** + + - ``input_key`` + - ``*algorithm_keys`` + - ``*meta_keys`` + + **Deleted Keys:** + + All other keys in the dict. + + **Added Keys:** + + - inputs (:obj:`torch.Tensor`): The forward data of models. + - data_samples (:obj:`~mmpretrain.structures.DataSample`): The + annotation info of the sample. + + Args: + input_key (str): The key of element to feed into the model forwarding. + Defaults to 'img'. + algorithm_keys (Sequence[str]): The keys of custom elements to be used + in the algorithm. Defaults to an empty tuple. + meta_keys (Sequence[str]): The keys of meta information to be saved in + the data sample. Defaults to :attr:`PackInputs.DEFAULT_META_KEYS`. + + .. admonition:: Default algorithm keys + + Besides the specified ``algorithm_keys``, we will set some default keys + into the output data sample and do some formatting. Therefore, you + don't need to set these keys in the ``algorithm_keys``. + + - ``gt_label``: The ground-truth label. The value will be converted + into a 1-D tensor. + - ``gt_score``: The ground-truth score. 
The value will be converted + into a 1-D tensor. + - ``mask``: The mask for some self-supervise tasks. The value will + be converted into a tensor. + + .. admonition:: Default meta keys + + - ``sample_idx``: The id of the image sample. + - ``img_path``: The path to the image file. + - ``ori_shape``: The original shape of the image as a tuple (H, W). + - ``img_shape``: The shape of the image after the pipeline as a + tuple (H, W). + - ``scale_factor``: The scale factor between the resized image and + the original image. + - ``flip``: A boolean indicating if image flip transform was used. + - ``flip_direction``: The flipping direction. + """ + + DEFAULT_META_KEYS = ('sample_idx', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction') + + def __init__(self, + input_key='img', + algorithm_keys=(), + meta_keys=DEFAULT_META_KEYS): + self.input_key = input_key + self.algorithm_keys = algorithm_keys + self.meta_keys = meta_keys + + @staticmethod + def format_input(input_): + if isinstance(input_, list): + return [PackInputs.format_input(item) for item in input_] + elif isinstance(input_, np.ndarray): + if input_.ndim == 2: # For grayscale image. + input_ = np.expand_dims(input_, -1) + if input_.ndim == 3 and not input_.flags.c_contiguous: + input_ = np.ascontiguousarray(input_.transpose(2, 0, 1)) + input_ = to_tensor(input_) + elif input_.ndim == 3: + # convert to tensor first to accelerate, see + # https://github.com/open-mmlab/mmdetection/pull/9533 + input_ = to_tensor(input_).permute(2, 0, 1).contiguous() + else: + # convert input with other shape to tensor without permute, + # like video input (num_crops, C, T, H, W). + input_ = to_tensor(input_) + elif isinstance(input_, Image.Image): + input_ = F.pil_to_tensor(input_) + elif not isinstance(input_, torch.Tensor): + raise TypeError(f'Unsupported input type {type(input_)}.') + + return input_ + + def transform(self, results: dict) -> dict: + """Method to pack the input data.""" + + packed_results = dict() + if self.input_key in results: + input_ = results[self.input_key] + packed_results['inputs'] = self.format_input(input_) + + data_sample = DataSample() + + # Set default keys + if 'gt_label' in results: + data_sample.set_gt_label(results['gt_label']) + if 'gt_score' in results: + data_sample.set_gt_score(results['gt_score']) + if 'mask' in results: + data_sample.set_mask(results['mask']) + + # Set custom algorithm keys + for key in self.algorithm_keys: + if key in results: + data_sample.set_field(results[key], key) + + # Set meta keys + for key in self.meta_keys: + if key in results: + data_sample.set_field(results[key], key, field_type='metainfo') + + packed_results['data_samples'] = data_sample + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(input_key='{self.input_key}', " + repr_str += f'algorithm_keys={self.algorithm_keys}, ' + repr_str += f'meta_keys={self.meta_keys})' + return repr_str + + +@TRANSFORMS.register_module() +class PackMultiTaskInputs(BaseTransform): + """Convert all image labels of multi-task dataset to a dict of tensor. 
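# Usage sketch (illustrative, not from the original files): packing one fake
# classification sample with `PackInputs`; the image, label and path below
# are made up.
import numpy as np

from mmpretrain.datasets.transforms.formatting import PackInputs

results = {
    'img': np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8),
    'gt_label': 3,
    'img_path': 'demo.jpg',
    'img_shape': (224, 224),
}
packed = PackInputs()(results)
print(packed['inputs'].shape)           # torch.Size([3, 224, 224]), CHW
print(packed['data_samples'].gt_label)  # tensor([3])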
+ + Args: + multi_task_fields (Sequence[str]): + input_key (str): + task_handlers (dict): + """ + + def __init__(self, + multi_task_fields, + input_key='img', + task_handlers=dict()): + self.multi_task_fields = multi_task_fields + self.input_key = input_key + self.task_handlers = defaultdict(PackInputs) + for task_name, task_handler in task_handlers.items(): + self.task_handlers[task_name] = TRANSFORMS.build(task_handler) + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + + result = {'img_path': 'a.png', 'gt_label': {'task1': 1, 'task3': 3}, + 'img': array([[[ 0, 0, 0]) + """ + packed_results = dict() + results = results.copy() + + if self.input_key in results: + input_ = results[self.input_key] + packed_results['inputs'] = PackInputs.format_input(input_) + + task_results = defaultdict(dict) + for field in self.multi_task_fields: + if field in results: + value = results.pop(field) + for k, v in value.items(): + task_results[k].update({field: v}) + + data_sample = MultiTaskDataSample() + for task_name, task_result in task_results.items(): + task_handler = self.task_handlers[task_name] + task_pack_result = task_handler({**results, **task_result}) + data_sample.set_field(task_pack_result['data_samples'], task_name) + + packed_results['data_samples'] = data_sample + return packed_results + + def __repr__(self): + repr = self.__class__.__name__ + task_handlers = ', '.join( + f"'{name}': {handler.__class__.__name__}" + for name, handler in self.task_handlers.items()) + repr += f'(multi_task_fields={self.multi_task_fields}, ' + repr += f"input_key='{self.input_key}', " + repr += f'task_handlers={{{task_handlers}}})' + return repr + + +@TRANSFORMS.register_module() +class Transpose(BaseTransform): + """Transpose numpy array. + + **Required Keys:** + + - ``*keys`` + + **Modified Keys:** + + - ``*keys`` + + Args: + keys (List[str]): The fields to convert to tensor. + order (List[int]): The output dimensions order. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def transform(self, results): + """Method to transpose array.""" + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL')) +class NumpyToPIL(BaseTransform): + """Convert the image from OpenCV format to :obj:`PIL.Image.Image`. + + **Required Keys:** + + - ``img`` + + **Modified Keys:** + + - ``img`` + + Args: + to_rgb (bool): Whether to convert img to rgb. Defaults to True. + """ + + def __init__(self, to_rgb: bool = False) -> None: + self.to_rgb = to_rgb + + def transform(self, results: dict) -> dict: + """Method to convert images to :obj:`PIL.Image.Image`.""" + img = results['img'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img + + results['img'] = Image.fromarray(img) + return results + + def __repr__(self) -> str: + return self.__class__.__name__ + f'(to_rgb={self.to_rgb})' + + +@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy')) +class PILToNumpy(BaseTransform): + """Convert img to :obj:`numpy.ndarray`. + + **Required Keys:** + + - ``img`` + + **Modified Keys:** + + - ``img`` + + Args: + to_bgr (bool): Whether to convert img to rgb. Defaults to True. + dtype (str, optional): The dtype of the converted numpy array. + Defaults to None. 
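# Usage sketch (illustrative, not from the original files): the task names
# and labels below are invented to show how `PackMultiTaskInputs` splits a
# per-task label dict into one DataSample per task.
import numpy as np

from mmpretrain.datasets.transforms.formatting import PackMultiTaskInputs

pack = PackMultiTaskInputs(multi_task_fields=('gt_label', ))
results = {
    'img': np.zeros((32, 32, 3), dtype=np.uint8),
    'gt_label': {'task1': 1, 'task2': 0},
}
packed = pack(results)
print(packed['data_samples'].task1.gt_label)  # tensor([1])
print(packed['data_samples'].task2.gt_label)  # tensor([0])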
+ """ + + def __init__(self, to_bgr: bool = False, dtype=None) -> None: + self.to_bgr = to_bgr + self.dtype = dtype + + def transform(self, results: dict) -> dict: + """Method to convert img to :obj:`numpy.ndarray`.""" + img = np.array(results['img'], dtype=self.dtype) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img + + results['img'] = img + return results + + def __repr__(self) -> str: + return self.__class__.__name__ + \ + f'(to_bgr={self.to_bgr}, dtype={self.dtype})' + + +@TRANSFORMS.register_module() +class Collect(BaseTransform): + """Collect and only reserve the specified fields. + + **Required Keys:** + + - ``*keys`` + + **Deleted Keys:** + + All keys except those in the argument ``*keys``. + + Args: + keys (Sequence[str]): The keys of the fields to be collected. + """ + + def __init__(self, keys): + self.keys = keys + + def transform(self, results): + data = {} + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py new file mode 100644 index 0000000..4c640f6 --- /dev/null +++ b/mmpretrain/datasets/transforms/processing.py @@ -0,0 +1,1795 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import math +import numbers +import re +import string +from enum import EnumMeta +from numbers import Number +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import mmcv +import mmengine +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as F +from mmcv.transforms import BaseTransform +from mmcv.transforms.utils import cache_randomness +from PIL import Image +from torchvision import transforms +from torchvision.transforms.transforms import InterpolationMode + +from mmpretrain.registry import TRANSFORMS + +try: + import albumentations +except ImportError: + albumentations = None + + +def _str_to_torch_dtype(t: str): + """mapping str format dtype to torch.dtype.""" + import torch # noqa: F401,F403 + return eval(f'torch.{t}') + + +def _interpolation_modes_from_str(t: str): + """mapping str format to Interpolation.""" + t = t.lower() + inverse_modes_mapping = { + 'nearest': InterpolationMode.NEAREST, + 'bilinear': InterpolationMode.BILINEAR, + 'bicubic': InterpolationMode.BICUBIC, + 'box': InterpolationMode.BOX, + 'hammimg': InterpolationMode.HAMMING, + 'lanczos': InterpolationMode.LANCZOS, + } + return inverse_modes_mapping[t] + + +class TorchVisonTransformWrapper: + + def __init__(self, transform, *args, **kwargs): + if 'interpolation' in kwargs and isinstance(kwargs['interpolation'], + str): + kwargs['interpolation'] = _interpolation_modes_from_str( + kwargs['interpolation']) + if 'dtype' in kwargs and isinstance(kwargs['dtype'], str): + kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype']) + self.t = transform(*args, **kwargs) + + def __call__(self, results): + results['img'] = self.t(results['img']) + return results + + def __repr__(self) -> str: + return f'TorchVision{repr(self.t)}' + + +def register_vision_transforms() -> List[str]: + """Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS`` + registry. + + Returns: + List[str]: A list of registered transforms' name. 
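# Usage sketch (illustrative, not from the original files): a lossless round
# trip between the OpenCV (BGR ndarray) and PIL representations using the two
# converters above.
import numpy as np

from mmpretrain.datasets.transforms.formatting import NumpyToPIL, PILToNumpy

bgr = np.random.randint(0, 256, (16, 16, 3), dtype=np.uint8)
results = {'img': bgr}
results = NumpyToPIL(to_rgb=True)(results)                  # BGR ndarray -> RGB PIL
results = PILToNumpy(to_bgr=True, dtype=np.uint8)(results)  # RGB PIL -> BGR ndarray
assert np.array_equal(results['img'], bgr)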
+ """ + vision_transforms = [] + for module_name in dir(torchvision.transforms): + if not re.match('[A-Z]', module_name): + # must startswith a capital letter + continue + _transform = getattr(torchvision.transforms, module_name) + if inspect.isclass(_transform) and callable( + _transform) and not isinstance(_transform, (EnumMeta)): + from functools import partial + TRANSFORMS.register_module( + module=partial( + TorchVisonTransformWrapper, transform=_transform), + name=f'torchvision/{module_name}') + vision_transforms.append(f'torchvision/{module_name}') + return vision_transforms + + +# register all the transforms in torchvision by using a transform wrapper +VISION_TRANSFORMS = register_vision_transforms() + + +@TRANSFORMS.register_module() +class RandomCrop(BaseTransform): + """Crop the given Image at a random location. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + crop_size (int | Sequence): Desired output size of the crop. If + crop_size is an int instead of sequence like (h, w), a square crop + (crop_size, crop_size) is made. + padding (int | Sequence, optional): Optional padding on each border + of the image. If a sequence of length 4 is provided, it is used to + pad left, top, right, bottom borders respectively. If a sequence + of length 2 is provided, it is used to pad left/right, top/bottom + borders, respectively. Default: None, which means no padding. + pad_if_needed (bool): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + Default: False. + pad_val (Number | Sequence[Number]): Pixel pad_val value for constant + fill. If a tuple of length 3, it is used to pad_val R, G, B + channels respectively. Default: 0. + padding_mode (str): Type of padding. Defaults to "constant". Should + be one of the following: + + - ``constant``: Pads with a constant value, this value is specified + with pad_val. + - ``edge``: pads with the last value at the edge of the image. + - ``reflect``: Pads with reflection of image without repeating the + last value on the edge. For example, padding [1, 2, 3, 4] + with 2 elements on both sides in reflect mode will result + in [3, 2, 1, 2, 3, 4, 3, 2]. + - ``symmetric``: Pads with reflection of image repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with + 2 elements on both sides in symmetric mode will result in + [2, 1, 1, 2, 3, 4, 4, 3]. + """ + + def __init__(self, + crop_size: Union[Sequence, int], + padding: Optional[Union[Sequence, int]] = None, + pad_if_needed: bool = False, + pad_val: Union[Number, Sequence[Number]] = 0, + padding_mode: str = 'constant'): + if isinstance(crop_size, Sequence): + assert len(crop_size) == 2 + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + else: + assert crop_size > 0 + self.crop_size = (crop_size, crop_size) + # check padding mode + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + self.padding = padding + self.pad_if_needed = pad_if_needed + self.pad_val = pad_val + self.padding_mode = padding_mode + + @cache_randomness + def rand_crop_params(self, img: np.ndarray): + """Get parameters for ``crop`` for a random crop. + + Args: + img (ndarray): Image to be cropped. + + Returns: + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to ``crop`` for random crop. 
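# Usage sketch (illustrative, not from the original files): once the module
# above has been imported, every torchvision transform is available from the
# registry under a 'torchvision/' prefix; the flip probability is arbitrary.
import numpy as np
from PIL import Image

import mmpretrain.datasets.transforms.processing  # noqa: F401, triggers registration
from mmpretrain.registry import TRANSFORMS

flip = TRANSFORMS.build(dict(type='torchvision/RandomHorizontalFlip', p=1.0))
results = {'img': Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8))}
results = flip(results)  # the wrapper reads and writes results['img']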
+ """ + h, w = img.shape[:2] + target_h, target_w = self.crop_size + if w == target_w and h == target_h: + return 0, 0, h, w + elif w < target_w or h < target_h: + target_w = min(w, target_w) + target_h = min(h, target_h) + + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) + + return offset_h, offset_w, target_h, target_w + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. + """ + img = results['img'] + if self.padding is not None: + img = mmcv.impad(img, padding=self.padding, pad_val=self.pad_val) + + # pad img if needed + if self.pad_if_needed: + h_pad = math.ceil(max(0, self.crop_size[0] - img.shape[0]) / 2) + w_pad = math.ceil(max(0, self.crop_size[1] - img.shape[1]) / 2) + + img = mmcv.impad( + img, + padding=(w_pad, h_pad, w_pad, h_pad), + pad_val=self.pad_val, + padding_mode=self.padding_mode) + + offset_h, offset_w, target_h, target_w = self.rand_crop_params(img) + img = mmcv.imcrop( + img, + np.array([ + offset_w, + offset_h, + offset_w + target_w - 1, + offset_h + target_h - 1, + ])) + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}' + repr_str += f', padding={self.padding}' + repr_str += f', pad_if_needed={self.pad_if_needed}' + repr_str += f', pad_val={self.pad_val}' + repr_str += f', padding_mode={self.padding_mode})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomResizedCrop(BaseTransform): + """Crop the given image to random scale and aspect ratio. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a + random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio + is made. This crop is finally resized to given size. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + scale (sequence | int): Desired output scale of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + crop_ratio_range (tuple): Range of the random size of the cropped + image compared to the original image. Defaults to (0.08, 1.0). + aspect_ratio_range (tuple): Range of the random aspect ratio of the + cropped image compared to the original image. + Defaults to (3. / 4., 4. / 3.). + max_attempts (int): Maximum number of attempts before falling back to + Central Crop. Defaults to 10. + interpolation (str): Interpolation method, accepted values are + 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to + 'bilinear'. + backend (str): The image resize backend type, accepted values are + 'cv2' and 'pillow'. Defaults to 'cv2'. + """ + + def __init__(self, + scale: Union[Sequence, int], + crop_ratio_range: Tuple[float, float] = (0.08, 1.0), + aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. 
/ 3.), + max_attempts: int = 10, + interpolation: str = 'bilinear', + backend: str = 'cv2') -> None: + if isinstance(scale, Sequence): + assert len(scale) == 2 + assert scale[0] > 0 and scale[1] > 0 + self.scale = scale + else: + assert scale > 0 + self.scale = (scale, scale) + if (crop_ratio_range[0] > crop_ratio_range[1]) or ( + aspect_ratio_range[0] > aspect_ratio_range[1]): + raise ValueError( + 'range should be of kind (min, max). ' + f'But received crop_ratio_range {crop_ratio_range} ' + f'and aspect_ratio_range {aspect_ratio_range}.') + assert isinstance(max_attempts, int) and max_attempts >= 0, \ + 'max_attempts mush be int and no less than 0.' + assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area', + 'lanczos') + + self.crop_ratio_range = crop_ratio_range + self.aspect_ratio_range = aspect_ratio_range + self.max_attempts = max_attempts + self.interpolation = interpolation + self.backend = backend + + @cache_randomness + def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (ndarray): Image to be cropped. + + Returns: + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to `crop` for a random sized crop. + """ + h, w = img.shape[:2] + area = h * w + + for _ in range(self.max_attempts): + target_area = np.random.uniform(*self.crop_ratio_range) * area + log_ratio = (math.log(self.aspect_ratio_range[0]), + math.log(self.aspect_ratio_range[1])) + aspect_ratio = math.exp(np.random.uniform(*log_ratio)) + target_w = int(round(math.sqrt(target_area * aspect_ratio))) + target_h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < target_w <= w and 0 < target_h <= h: + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) + + return offset_h, offset_w, target_h, target_w + + # Fallback to central crop + in_ratio = float(w) / float(h) + if in_ratio < min(self.aspect_ratio_range): + target_w = w + target_h = int(round(target_w / min(self.aspect_ratio_range))) + elif in_ratio > max(self.aspect_ratio_range): + target_h = h + target_w = int(round(target_h * max(self.aspect_ratio_range))) + else: # whole image + target_w = w + target_h = h + offset_h = (h - target_h) // 2 + offset_w = (w - target_w) // 2 + return offset_h, offset_w, target_h, target_w + + def transform(self, results: dict) -> dict: + """Transform function to randomly resized crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly resized cropped results, 'img_shape' + key in result dict is updated according to crop size. + """ + img = results['img'] + offset_h, offset_w, target_h, target_w = self.rand_crop_params(img) + img = mmcv.imcrop( + img, + bboxes=np.array([ + offset_w, offset_h, offset_w + target_w - 1, + offset_h + target_h - 1 + ])) + img = mmcv.imresize( + img, + tuple(self.scale[::-1]), + interpolation=self.interpolation, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. 
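# Usage sketch (illustrative, not from the original files): whatever crop is
# sampled internally, `RandomResizedCrop` always resizes it to the requested
# scale; the input size and ratio range below are arbitrary.
import numpy as np

from mmpretrain.datasets.transforms.processing import RandomResizedCrop

crop = RandomResizedCrop(scale=224, crop_ratio_range=(0.2, 1.0))
results = {'img': np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8)}
results = crop(results)
assert results['img'].shape == (224, 224, 3)
assert results['img_shape'] == (224, 224, 3)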
+ """ + repr_str = self.__class__.__name__ + f'(scale={self.scale}' + repr_str += ', crop_ratio_range=' + repr_str += f'{tuple(round(s, 4) for s in self.crop_ratio_range)}' + repr_str += ', aspect_ratio_range=' + repr_str += f'{tuple(round(r, 4) for r in self.aspect_ratio_range)}' + repr_str += f', max_attempts={self.max_attempts}' + repr_str += f', interpolation={self.interpolation}' + repr_str += f', backend={self.backend})' + return repr_str + + +@TRANSFORMS.register_module() +class EfficientNetRandomCrop(RandomResizedCrop): + """EfficientNet style RandomResizedCrop. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + scale (int): Desired output scale of the crop. Only int size is + accepted, a square crop (size, size) is made. + min_covered (Number): Minimum ratio of the cropped area to the original + area. Defaults to 0.1. + crop_padding (int): The crop padding parameter in efficientnet style + center crop. Defaults to 32. + crop_ratio_range (tuple): Range of the random size of the cropped + image compared to the original image. Defaults to (0.08, 1.0). + aspect_ratio_range (tuple): Range of the random aspect ratio of the + cropped image compared to the original image. + Defaults to (3. / 4., 4. / 3.). + max_attempts (int): Maximum number of attempts before falling back to + Central Crop. Defaults to 10. + interpolation (str): Interpolation method, accepted values are + 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to + 'bicubic'. + backend (str): The image resize backend type, accepted values are + 'cv2' and 'pillow'. Defaults to 'cv2'. + """ + + def __init__(self, + scale: int, + min_covered: float = 0.1, + crop_padding: int = 32, + interpolation: str = 'bicubic', + **kwarg): + assert isinstance(scale, int) + super().__init__(scale, interpolation=interpolation, **kwarg) + assert min_covered >= 0, 'min_covered should be no less than 0.' + assert crop_padding >= 0, 'crop_padding should be no less than 0.' + + self.min_covered = min_covered + self.crop_padding = crop_padding + + # https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py # noqa + @cache_randomness + def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (ndarray): Image to be cropped. + + Returns: + tuple: Params (offset_h, offset_w, target_h, target_w) to be + passed to `crop` for a random sized crop. + """ + h, w = img.shape[:2] + area = h * w + min_target_area = self.crop_ratio_range[0] * area + max_target_area = self.crop_ratio_range[1] * area + + for _ in range(self.max_attempts): + aspect_ratio = np.random.uniform(*self.aspect_ratio_range) + min_target_h = int( + round(math.sqrt(min_target_area / aspect_ratio))) + max_target_h = int( + round(math.sqrt(max_target_area / aspect_ratio))) + + if max_target_h * aspect_ratio > w: + max_target_h = int((w + 0.5 - 1e-7) / aspect_ratio) + if max_target_h * aspect_ratio > w: + max_target_h -= 1 + + max_target_h = min(max_target_h, h) + min_target_h = min(max_target_h, min_target_h) + + # slightly differs from tf implementation + target_h = int( + round(np.random.uniform(min_target_h, max_target_h))) + target_w = int(round(target_h * aspect_ratio)) + target_area = target_h * target_w + + # slight differs from tf. 
In tf, if target_area > max_target_area, + # area will be recalculated + if (target_area < min_target_area or target_area > max_target_area + or target_w > w or target_h > h + or target_area < self.min_covered * area): + continue + + offset_h = np.random.randint(0, h - target_h + 1) + offset_w = np.random.randint(0, w - target_w + 1) + + return offset_h, offset_w, target_h, target_w + + # Fallback to central crop + img_short = min(h, w) + crop_size = self.scale[0] / (self.scale[0] + + self.crop_padding) * img_short + + offset_h = max(0, int(round((h - crop_size) / 2.))) + offset_w = max(0, int(round((w - crop_size) / 2.))) + return offset_h, offset_w, crop_size, crop_size + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = super().__repr__()[:-1] + repr_str += f', min_covered={self.min_covered}' + repr_str += f', crop_padding={self.crop_padding})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomErasing(BaseTransform): + """Randomly selects a rectangle region in an image and erase pixels. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + + Args: + erase_prob (float): Probability that image will be randomly erased. + Default: 0.5 + min_area_ratio (float): Minimum erased area / input image area + Default: 0.02 + max_area_ratio (float): Maximum erased area / input image area + Default: 0.4 + aspect_range (sequence | float): Aspect ratio range of erased area. + if float, it will be converted to (aspect_ratio, 1/aspect_ratio) + Default: (3/10, 10/3) + mode (str): Fill method in erased area, can be: + + - const (default): All pixels are assign with the same value. + - rand: each pixel is assigned with a random value in [0, 255] + + fill_color (sequence | Number): Base color filled in erased area. + Defaults to (128, 128, 128). + fill_std (sequence | Number, optional): If set and ``mode`` is 'rand', + fill erased area with random color from normal distribution + (mean=fill_color, std=fill_std); If not set, fill erased area with + random color from uniform distribution (0~255). Defaults to None. + + Note: + See `Random Erasing Data Augmentation + `_ + + This paper provided 4 modes: RE-R, RE-M, RE-0, RE-255, and use RE-M as + default. The config of these 4 modes are: + + - RE-R: RandomErasing(mode='rand') + - RE-M: RandomErasing(mode='const', fill_color=(123.67, 116.3, 103.5)) + - RE-0: RandomErasing(mode='const', fill_color=0) + - RE-255: RandomErasing(mode='const', fill_color=255) + """ + + def __init__(self, + erase_prob=0.5, + min_area_ratio=0.02, + max_area_ratio=0.4, + aspect_range=(3 / 10, 10 / 3), + mode='const', + fill_color=(128, 128, 128), + fill_std=None): + assert isinstance(erase_prob, float) and 0. <= erase_prob <= 1. + assert isinstance(min_area_ratio, float) and 0. <= min_area_ratio <= 1. + assert isinstance(max_area_ratio, float) and 0. <= max_area_ratio <= 1. + assert min_area_ratio <= max_area_ratio, \ + 'min_area_ratio should be smaller than max_area_ratio' + if isinstance(aspect_range, float): + aspect_range = min(aspect_range, 1 / aspect_range) + aspect_range = (aspect_range, 1 / aspect_range) + assert isinstance(aspect_range, Sequence) and len(aspect_range) == 2 \ + and all(isinstance(x, float) for x in aspect_range), \ + 'aspect_range should be a float or Sequence with two float.' + assert all(x > 0 for x in aspect_range), \ + 'aspect_range should be positive.' 
+ assert aspect_range[0] <= aspect_range[1], \ + 'In aspect_range (min, max), min should be smaller than max.' + assert mode in ['const', 'rand'], \ + 'Please select `mode` from ["const", "rand"].' + if isinstance(fill_color, Number): + fill_color = [fill_color] * 3 + assert isinstance(fill_color, Sequence) and len(fill_color) == 3 \ + and all(isinstance(x, Number) for x in fill_color), \ + 'fill_color should be a float or Sequence with three int.' + if fill_std is not None: + if isinstance(fill_std, Number): + fill_std = [fill_std] * 3 + assert isinstance(fill_std, Sequence) and len(fill_std) == 3 \ + and all(isinstance(x, Number) for x in fill_std), \ + 'fill_std should be a float or Sequence with three int.' + + self.erase_prob = erase_prob + self.min_area_ratio = min_area_ratio + self.max_area_ratio = max_area_ratio + self.aspect_range = aspect_range + self.mode = mode + self.fill_color = fill_color + self.fill_std = fill_std + + def _fill_pixels(self, img, top, left, h, w): + """Fill pixels to the patch of image.""" + if self.mode == 'const': + patch = np.empty((h, w, 3), dtype=np.uint8) + patch[:, :] = np.array(self.fill_color, dtype=np.uint8) + elif self.fill_std is None: + # Uniform distribution + patch = np.random.uniform(0, 256, (h, w, 3)).astype(np.uint8) + else: + # Normal distribution + patch = np.random.normal(self.fill_color, self.fill_std, (h, w, 3)) + patch = np.clip(patch.astype(np.int32), 0, 255).astype(np.uint8) + + img[top:top + h, left:left + w] = patch + return img + + @cache_randomness + def random_disable(self): + """Randomly disable the transform.""" + return np.random.rand() > self.erase_prob + + @cache_randomness + def random_patch(self, img_h, img_w): + """Randomly generate patch the erase.""" + # convert the aspect ratio to log space to equally handle width and + # height. + log_aspect_range = np.log( + np.array(self.aspect_range, dtype=np.float32)) + aspect_ratio = np.exp(np.random.uniform(*log_aspect_range)) + area = img_h * img_w + area *= np.random.uniform(self.min_area_ratio, self.max_area_ratio) + + h = min(int(round(np.sqrt(area * aspect_ratio))), img_h) + w = min(int(round(np.sqrt(area / aspect_ratio))), img_w) + top = np.random.randint(0, img_h - h) if img_h > h else 0 + left = np.random.randint(0, img_w - w) if img_w > w else 0 + return top, left, h, w + + def transform(self, results): + """ + Args: + results (dict): Results dict from pipeline + + Returns: + dict: Results after the transformation. + """ + if self.random_disable(): + return results + + img = results['img'] + img_h, img_w = img.shape[:2] + + img = self._fill_pixels(img, *self.random_patch(img_h, img_w)) + + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(erase_prob={self.erase_prob}, ' + repr_str += f'min_area_ratio={self.min_area_ratio}, ' + repr_str += f'max_area_ratio={self.max_area_ratio}, ' + repr_str += f'aspect_range={self.aspect_range}, ' + repr_str += f'mode={self.mode}, ' + repr_str += f'fill_color={self.fill_color}, ' + repr_str += f'fill_std={self.fill_std})' + return repr_str + + +@TRANSFORMS.register_module() +class EfficientNetCenterCrop(BaseTransform): + r"""EfficientNet style center crop. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Args: + crop_size (int): Expected size after cropping with the format + of (h, w). + crop_padding (int): The crop padding parameter in efficientnet style + center crop. Defaults to 32. 
+ interpolation (str): Interpolation method, accepted values are + 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if + ``efficientnet_style`` is True. Defaults to 'bicubic'. + backend (str): The image resize backend type, accepted values are + `cv2` and `pillow`. Only valid if efficientnet style is True. + Defaults to `cv2`. + Notes: + - If the image is smaller than the crop size, return the original + image. + - The pipeline will be to first + to perform the center crop with the ``crop_size_`` as: + + .. math:: + + \text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} + + \text{crop_padding}} \times \text{short_edge} + + And then the pipeline resizes the img to the input crop size. + """ + + def __init__(self, + crop_size: int, + crop_padding: int = 32, + interpolation: str = 'bicubic', + backend: str = 'cv2'): + assert isinstance(crop_size, int) + assert crop_size > 0 + assert crop_padding >= 0 + assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area', + 'lanczos') + + self.crop_size = crop_size + self.crop_padding = crop_padding + self.interpolation = interpolation + self.backend = backend + + def transform(self, results: dict) -> dict: + """Transform function to randomly resized crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: EfficientNet style center cropped results, 'img_shape' + key in result dict is updated according to crop size. + """ + img = results['img'] + h, w = img.shape[:2] + + # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L118 # noqa + img_short = min(h, w) + crop_size = self.crop_size / (self.crop_size + + self.crop_padding) * img_short + + offset_h = max(0, int(round((h - crop_size) / 2.))) + offset_w = max(0, int(round((w - crop_size) / 2.))) + + # crop the image + img = mmcv.imcrop( + img, + bboxes=np.array([ + offset_w, offset_h, offset_w + crop_size - 1, + offset_h + crop_size - 1 + ])) + # resize image + img = mmcv.imresize( + img, (self.crop_size, self.crop_size), + interpolation=self.interpolation, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}' + repr_str += f', crop_padding={self.crop_padding}' + repr_str += f', interpolation={self.interpolation}' + repr_str += f', backend={self.backend})' + return repr_str + + +@TRANSFORMS.register_module() +class ResizeEdge(BaseTransform): + """Resize images along the specified edge. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + **Added Keys:** + + - scale + - scale_factor + + Args: + scale (int): The edge scale to resizing. + edge (str): The edge to resize. Defaults to 'short'. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. + Defaults to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + Defaults to 'bilinear'. + """ + + def __init__(self, + scale: int, + edge: str = 'short', + backend: str = 'cv2', + interpolation: str = 'bilinear') -> None: + allow_edges = ['short', 'long', 'width', 'height'] + assert edge in allow_edges, \ + f'Invalid edge "{edge}", please specify from {allow_edges}.' 
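# Usage sketch (illustrative, not from the original files): a worked example
# of the EfficientNet-style center crop formula above. For crop_size=224 and
# crop_padding=32 on a 256x341 image the short edge is 256, so
# crop_size_ = 224 / (224 + 32) * 256 = 224, which is then resized to 224x224.
import numpy as np

from mmpretrain.datasets.transforms.processing import EfficientNetCenterCrop

crop = EfficientNetCenterCrop(crop_size=224, crop_padding=32)
results = {'img': np.zeros((256, 341, 3), dtype=np.uint8)}
results = crop(results)
assert results['img'].shape == (224, 224, 3)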
+ self.edge = edge + self.scale = scale + self.backend = backend + self.interpolation = interpolation + + def _resize_img(self, results: dict) -> None: + """Resize images with ``results['scale']``.""" + + img, w_scale, h_scale = mmcv.imresize( + results['img'], + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape[:2] + results['scale'] = img.shape[:2][::-1] + results['scale_factor'] = (w_scale, h_scale) + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img', 'scale', 'scale_factor', + 'img_shape' keys are updated in result dict. + """ + assert 'img' in results, 'No `img` field in the input.' + + h, w = results['img'].shape[:2] + if any([ + # conditions to resize the width + self.edge == 'short' and w < h, + self.edge == 'long' and w > h, + self.edge == 'width', + ]): + width = self.scale + height = int(self.scale * h / w) + else: + height = self.scale + width = int(self.scale * w / h) + results['scale'] = (width, height) + + self._resize_img(results) + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'edge={self.edge}, ' + repr_str += f'backend={self.backend}, ' + repr_str += f'interpolation={self.interpolation})' + return repr_str + + +@TRANSFORMS.register_module() +class ColorJitter(BaseTransform): + """Randomly change the brightness, contrast and saturation of an image. + + Modified from + https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py + Licensed under the BSD 3-Clause License. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + + Args: + brightness (float | Sequence[float] (min, max)): How much to jitter + brightness. brightness_factor is chosen uniformly from + ``[max(0, 1 - brightness), 1 + brightness]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + contrast (float | Sequence[float] (min, max)): How much to jitter + contrast. contrast_factor is chosen uniformly from + ``[max(0, 1 - contrast), 1 + contrast]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + saturation (float | Sequence[float] (min, max)): How much to jitter + saturation. saturation_factor is chosen uniformly from + ``[max(0, 1 - saturation), 1 + saturation]`` or the given + ``[min, max]``. Should be non negative numbers. Defaults to 0. + hue (float | Sequence[float] (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from ``[-hue, hue]`` (0 <= hue + <= 0.5) or the given ``[min, max]`` (-0.5 <= min <= max <= 0.5). + Defaults to 0. + backend (str): The backend to operate the image. 
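# Usage sketch (illustrative, not from the original files): resizing the
# short edge to 256 keeps the aspect ratio, so a 480x640 image becomes
# 256x341.
import numpy as np

from mmpretrain.datasets.transforms.processing import ResizeEdge

resize = ResizeEdge(scale=256, edge='short')
results = {'img': np.zeros((480, 640, 3), dtype=np.uint8)}
results = resize(results)
assert results['img_shape'] == (256, 341)
assert results['scale'] == (341, 256)  # (width, height)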
Defaults to 'pillow' + """ + + def __init__(self, + brightness: Union[float, Sequence[float]] = 0., + contrast: Union[float, Sequence[float]] = 0., + saturation: Union[float, Sequence[float]] = 0., + hue: Union[float, Sequence[float]] = 0., + backend='pillow'): + self.brightness = self._set_range(brightness, 'brightness') + self.contrast = self._set_range(contrast, 'contrast') + self.saturation = self._set_range(saturation, 'saturation') + self.hue = self._set_range(hue, 'hue', center=0, bound=(-0.5, 0.5)) + self.backend = backend + + def _set_range(self, value, name, center=1, bound=(0, float('inf'))): + """Set the range of magnitudes.""" + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + f'If {name} is a single number, it must be non negative.') + value = (center - float(value), center + float(value)) + + if isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + value = np.clip(value, bound[0], bound[1]) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.warning(f'ColorJitter {name} values exceed the bound ' + f'{bound}, clipped to the bound.') + else: + raise TypeError(f'{name} should be a single number ' + 'or a list/tuple with length 2.') + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + else: + value = tuple(value) + + return value + + @cache_randomness + def _rand_params(self): + """Get random parameters including magnitudes and indices of + transforms.""" + trans_inds = np.random.permutation(4) + b, c, s, h = (None, ) * 4 + + if self.brightness is not None: + b = np.random.uniform(self.brightness[0], self.brightness[1]) + if self.contrast is not None: + c = np.random.uniform(self.contrast[0], self.contrast[1]) + if self.saturation is not None: + s = np.random.uniform(self.saturation[0], self.saturation[1]) + if self.hue is not None: + h = np.random.uniform(self.hue[0], self.hue[1]) + + return trans_inds, b, c, s, h + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: ColorJitter results, 'img' key is updated in result dict. + """ + img = results['img'] + trans_inds, brightness, contrast, saturation, hue = self._rand_params() + + for index in trans_inds: + if index == 0 and brightness is not None: + img = mmcv.adjust_brightness( + img, brightness, backend=self.backend) + elif index == 1 and contrast is not None: + img = mmcv.adjust_contrast(img, contrast, backend=self.backend) + elif index == 2 and saturation is not None: + img = mmcv.adjust_color( + img, alpha=saturation, backend=self.backend) + elif index == 3 and hue is not None: + img = mmcv.adjust_hue(img, hue, backend=self.backend) + + results['img'] = img + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(brightness={self.brightness}, ' + repr_str += f'contrast={self.contrast}, ' + repr_str += f'saturation={self.saturation}, ' + repr_str += f'hue={self.hue})' + return repr_str + + +@TRANSFORMS.register_module() +class Lighting(BaseTransform): + """Adjust images lighting using AlexNet-style PCA jitter. 
+ + **Required Keys:** + + - img + + **Modified Keys:** + + - img + + Args: + eigval (Sequence[float]): the eigenvalue of the convariance matrix + of pixel values, respectively. + eigvec (list[list]): the eigenvector of the convariance matrix of + pixel values, respectively. + alphastd (float): The standard deviation for distribution of alpha. + Defaults to 0.1. + to_rgb (bool): Whether to convert img to rgb. Defaults to False. + """ + + def __init__(self, + eigval: Sequence[float], + eigvec: Sequence[float], + alphastd: float = 0.1, + to_rgb: bool = False): + assert isinstance(eigval, Sequence), \ + f'eigval must be Sequence, got {type(eigval)} instead.' + assert isinstance(eigvec, Sequence), \ + f'eigvec must be Sequence, got {type(eigvec)} instead.' + for vec in eigvec: + assert isinstance(vec, Sequence) and len(vec) == len(eigvec[0]), \ + 'eigvec must contains lists with equal length.' + assert isinstance(alphastd, float), 'alphastd should be of type ' \ + f'float or int, got {type(alphastd)} instead.' + + self.eigval = np.array(eigval) + self.eigvec = np.array(eigvec) + self.alphastd = alphastd + self.to_rgb = to_rgb + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Lightinged results, 'img' key is updated in result dict. + """ + assert 'img' in results, 'No `img` field in the input.' + + img = results['img'] + img_lighting = mmcv.adjust_lighting( + img, + self.eigval, + self.eigvec, + alphastd=self.alphastd, + to_rgb=self.to_rgb) + results['img'] = img_lighting.astype(img.dtype) + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(eigval={self.eigval.tolist()}, ' + repr_str += f'eigvec={self.eigvec.tolist()}, ' + repr_str += f'alphastd={self.alphastd}, ' + repr_str += f'to_rgb={self.to_rgb})' + return repr_str + + +# 'Albu' is used in previous versions of mmpretrain, here is for compatibility +# users can use both 'Albumentations' and 'Albu'. +@TRANSFORMS.register_module(['Albumentations', 'Albu']) +class Albumentations(BaseTransform): + """Wrapper to use augmentation from albumentations library. + + **Required Keys:** + + - img + + **Modified Keys:** + + - img + - img_shape + + Adds custom transformations from albumentations library. + More details can be found in + `Albumentations `_. + An example of ``transforms`` is as followed: + + .. code-block:: + + [ + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=0.5), + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + + Args: + transforms (List[Dict]): List of albumentations transform configs. + keymap (Optional[Dict]): Mapping of mmpretrain to albumentations + fields, in format {'input key':'albumentation-style key'}. + Defaults to None. + + Example: + >>> import mmcv + >>> from mmpretrain.datasets import Albumentations + >>> transforms = [ + ... dict( + ... type='ShiftScaleRotate', + ... shift_limit=0.0625, + ... scale_limit=0.0, + ... rotate_limit=0, + ... interpolation=1, + ... p=0.5), + ... dict( + ... type='RandomBrightnessContrast', + ... 
brightness_limit=[0.1, 0.3], + ... contrast_limit=[0.1, 0.3], + ... p=0.2), + ... dict(type='ChannelShuffle', p=0.1), + ... dict( + ... type='OneOf', + ... transforms=[ + ... dict(type='Blur', blur_limit=3, p=1.0), + ... dict(type='MedianBlur', blur_limit=3, p=1.0) + ... ], + ... p=0.1), + ... ] + >>> albu = Albumentations(transforms) + >>> data = {'img': mmcv.imread('./demo/demo.JPEG')} + >>> data = albu(data) + >>> print(data['img'].shape) + (375, 500, 3) + """ + + def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None): + if albumentations is None: + raise RuntimeError('albumentations is not installed') + else: + from albumentations import Compose as albu_Compose + + assert isinstance(transforms, list), 'transforms must be a list.' + if keymap is not None: + assert isinstance(keymap, dict), 'keymap must be None or a dict. ' + + self.transforms = transforms + + self.aug = albu_Compose( + [self.albu_builder(t) for t in self.transforms]) + + if not keymap: + self.keymap_to_albu = dict(img='image') + else: + self.keymap_to_albu = keymap + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg: Dict): + """Import a module from albumentations. + + It inherits some of :func:`build_from_cfg` logic. + Args: + cfg (dict): Config dict. It should at least contain the key "type". + Returns: + obj: The constructed object. + """ + + assert isinstance(cfg, dict) and 'type' in cfg, 'each item in ' \ + "transforms must be a dict with keyword 'type'." + args = cfg.copy() + + obj_type = args.pop('type') + if mmengine.is_str(obj_type): + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d, keymap): + """Dictionary mapper. + + Renames keys according to keymap provided. + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. + """ + + updated_dict = {} + for k, v in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + def transform(self, results: Dict) -> Dict: + """Transform function to perform albumentations transforms. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Transformed results, 'img' and 'img_shape' keys are + updated in result dict. + """ + assert 'img' in results, 'No `img` field in the input.' + + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + results = self.aug(**results) + + # back to the original format + results = self.mapper(results, self.keymap_back) + results['img_shape'] = results['img'].shape[:2] + + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(transforms={repr(self.transforms)})' + return repr_str + + +@TRANSFORMS.register_module() +class SimMIMMaskGenerator(BaseTransform): + """Generate random block mask for each Image. + + **Added Keys**: + + - mask + + This module is used in SimMIM to generate masks. + + Args: + input_size (int): Size of input image. Defaults to 192. + mask_patch_size (int): Size of each block mask. Defaults to 32. + model_patch_size (int): Patch size of each token. 
Defaults to 4. + mask_ratio (float): The mask ratio of image. Defaults to 0.6. + """ + + def __init__(self, + input_size: int = 192, + mask_patch_size: int = 32, + model_patch_size: int = 4, + mask_ratio: float = 0.6): + self.input_size = input_size + self.mask_patch_size = mask_patch_size + self.model_patch_size = model_patch_size + self.mask_ratio = mask_ratio + + assert self.input_size % self.mask_patch_size == 0 + assert self.mask_patch_size % self.model_patch_size == 0 + + self.rand_size = self.input_size // self.mask_patch_size + self.scale = self.mask_patch_size // self.model_patch_size + + self.token_count = self.rand_size**2 + self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) + + def transform(self, results: dict) -> dict: + """Method to generate random block mask for each Image in SimMIM. + + Args: + results (dict): Result dict from previous pipeline. + + Returns: + dict: Result dict with added key ``mask``. + """ + mask_idx = np.random.permutation(self.token_count)[:self.mask_count] + mask = np.zeros(self.token_count, dtype=int) + mask[mask_idx] = 1 + + mask = mask.reshape((self.rand_size, self.rand_size)) + mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) + + results.update({'mask': mask}) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(input_size={self.input_size}, ' + repr_str += f'mask_patch_size={self.mask_patch_size}, ' + repr_str += f'model_patch_size={self.model_patch_size}, ' + repr_str += f'mask_ratio={self.mask_ratio})' + return repr_str + + +@TRANSFORMS.register_module() +class BEiTMaskGenerator(BaseTransform): + """Generate mask for image. + + **Added Keys**: + + - mask + + This module is borrowed from + https://github.com/microsoft/unilm/tree/master/beit + + Args: + input_size (int): The size of input image. + num_masking_patches (int): The number of patches to be masked. + min_num_patches (int): The minimum number of patches to be masked + in the process of generating mask. Defaults to 4. + max_num_patches (int, optional): The maximum number of patches to be + masked in the process of generating mask. Defaults to None. + min_aspect (float): The minimum aspect ratio of mask blocks. Defaults + to 0.3. + min_aspect (float, optional): The minimum aspect ratio of mask blocks. + Defaults to None. + """ + + def __init__(self, + input_size: int, + num_masking_patches: int, + min_num_patches: int = 4, + max_num_patches: Optional[int] = None, + min_aspect: float = 0.3, + max_aspect: Optional[float] = None) -> None: + if not isinstance(input_size, tuple): + input_size = (input_size, ) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + + self.num_masking_patches = num_masking_patches + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None \ + else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int: + """Generate mask recursively. + + Args: + mask (np.ndarray): The mask to be generated. + max_mask_patches (int): The maximum number of patches to be masked. + + Returns: + int: The number of patches masked. 
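# Usage sketch (illustrative, not from the original files): with the default
# settings above (input 192, mask patch 32, model patch 4, ratio 0.6) the
# mask grid is 192/32 = 6x6 = 36 tokens, ceil(36 * 0.6) = 22 of them are
# masked, and the grid is upsampled by 32/4 = 8 into a 48x48 array.
from mmpretrain.datasets.transforms.processing import SimMIMMaskGenerator

gen = SimMIMMaskGenerator()  # all defaults
results = gen({})
assert results['mask'].shape == (48, 48)
assert int(results['mask'].sum()) == 22 * 8 * 8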
+ """ + delta = 0 + for _ in range(10): + target_area = np.random.uniform(self.min_num_patches, + max_mask_patches) + aspect_ratio = math.exp(np.random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = np.random.randint(0, self.height - h) + left = np.random.randint(0, self.width - w) + + num_masked = mask[top:top + h, left:left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + if delta > 0: + break + return delta + + def transform(self, results: dict) -> dict: + """Method to generate random block mask for each Image in BEiT. + + Args: + results (dict): Result dict from previous pipeline. + + Returns: + dict: Result dict with added key ``mask``. + """ + mask = np.zeros(shape=(self.height, self.width), dtype=int) + + mask_count = 0 + while mask_count != self.num_masking_patches: + max_mask_patches = self.num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + mask_count += delta + results.update({'mask': mask}) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(height={self.height}, ' + repr_str += f'width={self.width}, ' + repr_str += f'num_patches={self.num_patches}, ' + repr_str += f'num_masking_patches={self.num_masking_patches}, ' + repr_str += f'min_num_patches={self.min_num_patches}, ' + repr_str += f'max_num_patches={self.max_num_patches}, ' + repr_str += f'log_aspect_ratio={self.log_aspect_ratio})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform): + """Crop the given PIL Image to random size and aspect ratio with random + interpolation. + + **Required Keys**: + + - img + + **Modified Keys**: + + - img + + **Added Keys**: + + - target_img + + This module is borrowed from + https://github.com/microsoft/unilm/tree/master/beit. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a + random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio + is made. This crop is finally resized to given size. This is popularly used + to train the Inception networks. This module first crops the image and + resizes the crop to two different sizes. + + Args: + size (Union[tuple, int]): Expected output size of each edge of the + first image. + second_size (Union[tuple, int], optional): Expected output size of each + edge of the second image. + scale (tuple[float, float]): Range of size of the origin size cropped. + Defaults to (0.08, 1.0). + ratio (tuple[float, float]): Range of aspect ratio of the origin aspect + ratio cropped. Defaults to (3./4., 4./3.). + interpolation (str): The interpolation for the first image. Defaults + to ``bilinear``. + second_interpolation (str): The interpolation for the second image. + Defaults to ``lanczos``. + """ + + def __init__(self, + size: Union[tuple, int], + second_size=None, + scale=(0.08, 1.0), + ratio=(3. / 4., 4. 
/ 3.), + interpolation='bilinear', + second_interpolation='lanczos') -> None: + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if second_size is not None: + if isinstance(second_size, tuple): + self.second_size = second_size + else: + self.second_size = (second_size, second_size) + else: + self.second_size = None + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + ('range should be of kind (min, max)') + + if interpolation == 'random': + self.interpolation = ('bilinear', 'bicubic') + else: + self.interpolation = interpolation + self.second_interpolation = second_interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img: np.ndarray, scale: tuple, + ratio: tuple) -> Sequence[int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (np.ndarray): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect + ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + img_h, img_w = img.shape[:2] + area = img_h * img_w + + for _ in range(10): + target_area = np.random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(np.random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w < img_w and h < img_h: + i = np.random.randint(0, img_h - h) + j = np.random.randint(0, img_w - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img_w / img_h + if in_ratio < min(ratio): + w = img_w + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img_h + w = int(round(h * max(ratio))) + else: # whole image + w = img_w + h = img_h + i = (img_h - h) // 2 + j = (img_w - w) // 2 + return i, j, h, w + + def transform(self, results: dict) -> dict: + """Crop the given image and resize it to two different sizes. + + This module crops the given image randomly and resize the crop to two + different sizes. This is popularly used in BEiT-style masked image + modeling, where an off-the-shelf model is used to provide the target. + + Args: + results (dict): Results from previous pipeline. + + Returns: + dict: Results after applying this transformation. + """ + img = results['img'] + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = np.random.choice(self.interpolation) + else: + interpolation = self.interpolation + if self.second_size is None: + img = img[i:i + h, j:j + w] + img = mmcv.imresize(img, self.size, interpolation=interpolation) + results.update({'img': img}) + else: + img = img[i:i + h, j:j + w] + img_sample = mmcv.imresize( + img, self.size, interpolation=interpolation) + img_target = mmcv.imresize( + img, self.second_size, interpolation=self.second_interpolation) + results.update({'img': [img_sample, img_target]}) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, ' + repr_str += f'second_size={self.second_size}, ' + repr_str += f'interpolation={self.interpolation}, ' + repr_str += f'second_interpolation={self.second_interpolation}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'ratio={self.ratio})' + return repr_str + + +@TRANSFORMS.register_module() +class CleanCaption(BaseTransform): + """Clean caption text. 
+ + Remove some useless punctuation for the caption task. + + **Required Keys:** + + - ``*keys`` + + **Modified Keys:** + + - ``*keys`` + + Args: + keys (Sequence[str], optional): The keys of text to be cleaned. + Defaults to 'gt_caption'. + remove_chars (str): The characters to be removed. Defaults to + :py:attr:`string.punctuation`. + lowercase (bool): Whether to convert the text to lowercase. + Defaults to True. + remove_dup_space (bool): Whether to remove duplicated whitespaces. + Defaults to True. + strip (bool): Whether to remove leading and trailing whitespaces. + Defaults to True. + """ + + def __init__( + self, + keys='gt_caption', + remove_chars=string.punctuation, + lowercase=True, + remove_dup_space=True, + strip=True, + ): + if isinstance(keys, str): + keys = [keys] + self.keys = keys + self.transtab = str.maketrans({ch: None for ch in remove_chars}) + self.lowercase = lowercase + self.remove_dup_space = remove_dup_space + self.strip = strip + + def _clean(self, text): + """Perform text cleaning before tokenizer.""" + + if self.strip: + text = text.strip() + + text = text.translate(self.transtab) + + if self.remove_dup_space: + text = re.sub(r'\s{2,}', ' ', text) + + if self.lowercase: + text = text.lower() + + return text + + def clean(self, text): + """Perform text cleaning before tokenizer.""" + if isinstance(text, (list, tuple)): + return [self._clean(item) for item in text] + elif isinstance(text, str): + return self._clean(text) + else: + raise TypeError('text must be a string or a list of strings') + + def transform(self, results: dict) -> dict: + """Method to clean the input text data.""" + for key in self.keys: + results[key] = self.clean(results[key]) + return results + + +@TRANSFORMS.register_module() +class OFAAddObjects(BaseTransform): + """Append the ``objects`` annotations to the question prompt for OFA.""" + + def transform(self, results: dict) -> dict: + if 'objects' not in results: + raise ValueError( + 'Some OFA fine-tuned models require `objects` field in the ' + 'dataset, which is generated by VinVL. Or please use ' + 'zero-shot configs. See ' + 'https://github.com/OFA-Sys/OFA/issues/189') + + if 'question' in results: + prompt = '{} object: {}'.format( + results['question'], + ' '.join(results['objects']), + ) + results['decoder_prompt'] = prompt + results['question'] = prompt + + return results + + +@TRANSFORMS.register_module() +class RandomTranslatePad(BaseTransform): + """Pad the image into a square of ``size``, optionally with a random + translation, and shift ``gt_bboxes`` accordingly.""" + + def __init__(self, size=640, aug_translate=False): + self.size = size + self.aug_translate = aug_translate + + @cache_randomness + def rand_translate_params(self, dh, dw): + top = np.random.randint(0, dh) + left = np.random.randint(0, dw) + return top, left + + def transform(self, results: dict) -> dict: + img = results['img'] + h, w = img.shape[:-1] + dw = self.size - w + dh = self.size - h + if self.aug_translate: + top, left = self.rand_translate_params(dh, dw) + else: + top = round(dh / 2.0 - 0.1) + left = round(dw / 2.0 - 0.1) + + out_img = np.zeros((self.size, self.size, 3), dtype=np.float32) + out_img[top:top + h, left:left + w, :] = img + results['img'] = out_img + results['img_shape'] = (self.size, self.size) + + # translate box + if 'gt_bboxes' in results.keys(): + for i in range(len(results['gt_bboxes'])): + box = results['gt_bboxes'][i] + box[0], box[2] = box[0] + left, box[2] + left + box[1], box[3] = box[1] + top, box[3] + top + results['gt_bboxes'][i] = box + + return results + + +@TRANSFORMS.register_module() +class MAERandomResizedCrop(transforms.RandomResizedCrop): + """RandomResizedCrop for matching TF/TPU implementation: no for-loop is + used.
+ + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 # noqa: E501 + """ + + @staticmethod + def get_params(img: Image.Image, scale: tuple, ratio: tuple) -> Tuple: + width, height = img.size + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1, )).item() + j = torch.randint(0, width - w + 1, size=(1, )).item() + + return i, j, h, w + + def forward(self, results: dict) -> dict: + """The forward function of MAERandomResizedCrop. + + Args: + results (dict): The results dict contains the image and all these + information related to the image. + + Returns: + dict: The results dict contains the cropped image and all these + information related to the image. + """ + img = results['img'] + i, j, h, w = self.get_params(img, self.scale, self.ratio) + img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation) + results['img'] = img + return results diff --git a/mmpretrain/datasets/transforms/utils.py b/mmpretrain/datasets/transforms/utils.py new file mode 100644 index 0000000..d794048 --- /dev/null +++ b/mmpretrain/datasets/transforms/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Union + +from mmcv.transforms import BaseTransform + +PIPELINE_TYPE = List[Union[dict, BaseTransform]] + + +def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int: + """Returns the index of the transform in a pipeline. + + Args: + pipeline (List[dict] | List[BaseTransform]): The transforms list. + target (str): The target transform class name. + + Returns: + int: The transform index. Returns -1 if not found. + """ + for i, transform in enumerate(pipeline): + if isinstance(transform, dict): + if isinstance(transform['type'], type): + if transform['type'].__name__ == target: + return i + else: + if transform['type'] == target: + return i + else: + if transform.__class__.__name__ == target: + return i + + return -1 + + +def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False): + """Remove the target transform type from the pipeline. + + Args: + pipeline (List[dict] | List[BaseTransform]): The transforms list. + target (str): The target transform class name. + inplace (bool): Whether to modify the pipeline inplace. + + Returns: + The modified transform. + """ + idx = get_transform_idx(pipeline, target) + if not inplace: + pipeline = copy.deepcopy(pipeline) + while idx >= 0: + pipeline.pop(idx) + idx = get_transform_idx(pipeline, target) + + return pipeline diff --git a/mmpretrain/datasets/transforms/wrappers.py b/mmpretrain/datasets/transforms/wrappers.py new file mode 100644 index 0000000..c0dfd73 --- /dev/null +++ b/mmpretrain/datasets/transforms/wrappers.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +from typing import Callable, List, Union + +from mmcv.transforms import BaseTransform, Compose + +from mmpretrain.registry import TRANSFORMS + +# Define type of transform or transform config +Transform = Union[dict, Callable[[dict], dict]] + + +@TRANSFORMS.register_module() +class MultiView(BaseTransform): + """A transform wrapper for multiple views of an image. + + Args: + transforms (list[dict | callable], optional): Sequence of transform + object or config dict to be wrapped. + mapping (dict): A dict that defines the input key mapping. + The keys corresponds to the inner key (i.e., kwargs of the + ``transform`` method), and should be string type. The values + corresponds to the outer keys (i.e., the keys of the + data/results), and should have a type of string, list or dict. + None means not applying input mapping. Default: None. + allow_nonexist_keys (bool): If False, the outer keys in the mapping + must exist in the input data, or an exception will be raised. + Default: False. + + Examples: + >>> # Example 1: MultiViews 1 pipeline with 2 views + >>> pipeline = [ + >>> dict(type='MultiView', + >>> num_views=2, + >>> transforms=[ + >>> [ + >>> dict(type='Resize', scale=224))], + >>> ]) + >>> ] + >>> # Example 2: MultiViews 2 pipelines, the first with 2 views, + >>> # the second with 6 views + >>> pipeline = [ + >>> dict(type='MultiView', + >>> num_views=[2, 6], + >>> transforms=[ + >>> [ + >>> dict(type='Resize', scale=224)], + >>> [ + >>> dict(type='Resize', scale=224), + >>> dict(type='RandomSolarize')], + >>> ]) + >>> ] + """ + + def __init__(self, transforms: List[List[Transform]], + num_views: Union[int, List[int]]) -> None: + + if isinstance(num_views, int): + num_views = [num_views] + assert isinstance(num_views, List) + assert len(num_views) == len(transforms) + self.num_views = num_views + + self.pipelines = [] + for trans in transforms: + pipeline = Compose(trans) + self.pipelines.append(pipeline) + + self.transforms = [] + for i in range(len(num_views)): + self.transforms.extend([self.pipelines[i]] * num_views[i]) + + def transform(self, results: dict) -> dict: + """Apply transformation to inputs. + + Args: + results (dict): Result dict from previous pipelines. + + Returns: + dict: Transformed results. + """ + multi_views_outputs = dict(img=[]) + for trans in self.transforms: + inputs = copy.deepcopy(results) + outputs = trans(inputs) + + multi_views_outputs['img'].append(outputs['img']) + results.update(multi_views_outputs) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + '(' + for i, p in enumerate(self.pipelines): + repr_str += f'\nPipeline {i + 1} with {self.num_views[i]} views:\n' + repr_str += str(p) + repr_str += ')' + return repr_str + + +@TRANSFORMS.register_module() +class ApplyToList(BaseTransform): + """A transform wrapper to apply the wrapped transforms to a list of items. + For example, to load and resize a list of images. + + Args: + transforms (list[dict | callable]): Sequence of transform config dict + to be wrapped. + scatter_key (str): The key to scatter data dict. If the field is a + list, scatter the list to multiple data dicts to do transformation. + collate_keys (List[str]): The keys to collate from multiple data dicts. + The fields in ``collate_keys`` will be composed into a list after + transformation, and the other fields will be adopted from the + first data dict. 
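+ + Example (an illustrative sketch only, not a config shipped with this + patch; it assumes a sample whose ``img_path`` field is a list of image + files that should each be loaded and resized): + + >>> pipeline = [ + >>> dict(type='ApplyToList', + >>> scatter_key='img_path', + >>> collate_keys=['img', 'img_shape'], + >>> transforms=[ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='Resize', scale=224), + >>> ]), + >>> ]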
+ """ + + def __init__(self, transforms, scatter_key, collate_keys): + super().__init__() + + self.transforms = Compose([TRANSFORMS.build(t) for t in transforms]) + self.scatter_key = scatter_key + self.collate_keys = set(collate_keys) + self.collate_keys.add(self.scatter_key) + + def transform(self, results: dict): + scatter_field = results.get(self.scatter_key) + + if isinstance(scatter_field, list): + scattered_results = [] + for item in scatter_field: + single_results = copy.deepcopy(results) + single_results[self.scatter_key] = item + scattered_results.append(self.transforms(single_results)) + + final_output = scattered_results[0] + + # merge output list to single output + for key in scattered_results[0].keys(): + if key in self.collate_keys: + final_output[key] = [ + single[key] for single in scattered_results + ] + return final_output + else: + return self.transforms(results) diff --git a/mmpretrain/datasets/utils.py b/mmpretrain/datasets/utils.py new file mode 100644 index 0000000..fcb60e4 --- /dev/null +++ b/mmpretrain/datasets/utils.py @@ -0,0 +1,243 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import gzip +import hashlib +import os +import os.path +import shutil +import tarfile +import tempfile +import urllib.error +import urllib.request +import zipfile + +from mmengine.fileio import LocalBackend, get_file_backend + +__all__ = [ + 'rm_suffix', 'check_integrity', 'download_and_extract_archive', + 'open_maybe_compressed_file' +] + + +def rm_suffix(s, suffix=None): + if suffix is None: + return s[:s.rfind('.')] + else: + return s[:s.rfind(suffix)] + + +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024): + md5 = hashlib.md5() + backend = get_file_backend(fpath, enable_singleton=True) + if isinstance(backend, LocalBackend): + # Enable chunk update for local file. + with open(fpath, 'rb') as f: + for chunk in iter(lambda: f.read(chunk_size), b''): + md5.update(chunk) + else: + md5.update(backend.get(fpath)) + return md5.hexdigest() + + +def check_md5(fpath, md5, **kwargs): + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath, md5=None): + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def download_url_to_file(url, dst, hash_prefix=None, progress=True): + """Download object at the given URL to a local path. + + Modified from + https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file + + Args: + url (str): URL of the object to download + dst (str): Full path where object will be saved, + e.g. ``/tmp/temporary_file`` + hash_prefix (string, optional): If not None, the SHA256 downloaded + file should start with ``hash_prefix``. Defaults to None. + progress (bool): whether or not to display a progress bar to stderr. + Defaults to True + """ + file_size = None + req = urllib.request.Request(url) + u = urllib.request.urlopen(req) + meta = u.info() + if hasattr(meta, 'getheaders'): + content_length = meta.getheaders('Content-Length') + else: + content_length = meta.get_all('Content-Length') + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + # We deliberately save it in a temp file and move it after download is + # complete. This prevents a local file being overridden by a broken + # download. 
+ dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + + import rich.progress + columns = [ + rich.progress.DownloadColumn(), + rich.progress.BarColumn(bar_width=None), + rich.progress.TimeRemainingColumn(), + ] + try: + if hash_prefix is not None: + sha256 = hashlib.sha256() + with rich.progress.Progress(*columns) as pbar: + task = pbar.add_task('download', total=file_size, visible=progress) + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + if hash_prefix is not None: + sha256.update(buffer) + pbar.update(task, advance=len(buffer)) + + f.close() + if hash_prefix is not None: + digest = sha256.hexdigest() + if digest[:len(hash_prefix)] != hash_prefix: + raise RuntimeError( + 'invalid hash value (expected "{}", got "{}")'.format( + hash_prefix, digest)) + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + +def download_url(url, root, filename=None, md5=None): + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from. + root (str): Directory to place downloaded file in. + filename (str | None): Name to save the file under. + If filename is None, use the basename of the URL. + md5 (str | None): MD5 checksum of the download. + If md5 is None, download without md5 check. + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if check_integrity(fpath, md5): + print(f'Using downloaded and verified file: {fpath}') + else: + try: + print(f'Downloading {url} to {fpath}') + download_url_to_file(url, fpath) + except (urllib.error.URLError, IOError) as e: + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' 
+ f' Downloading {url} to {fpath}') + download_url_to_file(url, fpath) + else: + raise e + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError('File not found or corrupted.') + + +def _is_tarxz(filename): + return filename.endswith('.tar.xz') + + +def _is_tar(filename): + return filename.endswith('.tar') + + +def _is_targz(filename): + return filename.endswith('.tar.gz') + + +def _is_tgz(filename): + return filename.endswith('.tgz') + + +def _is_gzip(filename): + return filename.endswith('.gz') and not filename.endswith('.tar.gz') + + +def _is_zip(filename): + return filename.endswith('.zip') + + +def extract_archive(from_path, to_path=None, remove_finished=False): + if to_path is None: + to_path = os.path.dirname(from_path) + + if _is_tar(from_path): + with tarfile.open(from_path, 'r') as tar: + tar.extractall(path=to_path) + elif _is_targz(from_path) or _is_tgz(from_path): + with tarfile.open(from_path, 'r:gz') as tar: + tar.extractall(path=to_path) + elif _is_tarxz(from_path): + with tarfile.open(from_path, 'r:xz') as tar: + tar.extractall(path=to_path) + elif _is_gzip(from_path): + to_path = os.path.join( + to_path, + os.path.splitext(os.path.basename(from_path))[0]) + with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f: + out_f.write(zip_f.read()) + elif _is_zip(from_path): + with zipfile.ZipFile(from_path, 'r') as z: + z.extractall(to_path) + else: + raise ValueError(f'Extraction of {from_path} not supported') + + if remove_finished: + os.remove(from_path) + + +def download_and_extract_archive(url, + download_root, + extract_root=None, + filename=None, + md5=None, + remove_finished=False): + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print(f'Extracting {archive} to {extract_root}') + extract_archive(archive, extract_root, remove_finished) + + +def open_maybe_compressed_file(path: str): + """Return a file object that possibly decompresses 'path' on the fly. + + Decompression occurs when argument `path` is a string and ends with '.gz' + or '.xz'. + """ + if not isinstance(path, str): + return path + if path.endswith('.gz'): + import gzip + return gzip.open(path, 'rb') + if path.endswith('.xz'): + import lzma + return lzma.open(path, 'rb') + return open(path, 'rb') diff --git a/mmpretrain/datasets/vg_vqa.py b/mmpretrain/datasets/vg_vqa.py new file mode 100644 index 0000000..2d83884 --- /dev/null +++ b/mmpretrain/datasets/vg_vqa.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine.fileio import load + +from mmpretrain.registry import DATASETS +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class VGVQA(BaseDataset): + """Visual Genome VQA dataset.""" + + def load_data_list(self) -> List[dict]: + """Load data list. + + Compare to BaseDataset, the only difference is that coco_vqa annotation + file is already a list of data. There is no 'metainfo'. + """ + + raw_data_list = load(self.ann_file) + if not isinstance(raw_data_list, list): + raise TypeError( + f'The VQA annotations loaded from annotation file ' + f'should be a dict, but got {type(raw_data_list)}!') + + # load and parse data_infos. 
+ data_list = [] + for raw_data_info in raw_data_list: + # parse raw data information to target format + data_info = self.parse_data_info(raw_data_info) + if isinstance(data_info, dict): + # For VQA tasks, each `data_info` looks like: + # { + # "question_id": 986769, + # "question": "How many people are there?", + # "answer": "two", + # "image": "image/1.jpg", + # "dataset": "vg" + # } + + # change 'image' key to 'img_path' + # TODO: This process will be removed, after the annotation file + # is preprocess. + data_info['img_path'] = data_info['image'] + del data_info['image'] + + if 'answer' in data_info: + # add answer_weight & answer_count, delete duplicate answer + if data_info['dataset'] == 'vqa': + answer_weight = {} + for answer in data_info['answer']: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len( + data_info['answer']) + else: + answer_weight[answer] = 1 / len( + data_info['answer']) + + data_info['answer'] = list(answer_weight.keys()) + data_info['answer_weight'] = list( + answer_weight.values()) + data_info['answer_count'] = len(answer_weight) + + elif data_info['dataset'] == 'vg': + data_info['answers'] = [data_info['answer']] + data_info['answer_weight'] = [0.2] + data_info['answer_count'] = 1 + + data_list.append(data_info) + + else: + raise TypeError( + f'Each VQA data element loaded from annotation file ' + f'should be a dict, but got {type(data_info)}!') + + return data_list diff --git a/mmpretrain/datasets/visual_genome.py b/mmpretrain/datasets/visual_genome.py new file mode 100644 index 0000000..8c33b86 --- /dev/null +++ b/mmpretrain/datasets/visual_genome.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from itertools import chain +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class VisualGenomeQA(BaseDataset): + """Visual Genome Question Answering dataset. + + dataset structure: :: + + data_root + ├── image + │   ├── 1.jpg + │   ├── 2.jpg + │   └── ... + └── question_answers.json + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. Defaults to ``"image"``. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to ``"question_answers.json"``. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str = 'image', + ann_file: str = 'question_answers.json', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def _create_image_index(self): + img_prefix = self.data_prefix['img_path'] + + files = mmengine.list_dir_or_file(img_prefix, list_dir=False) + image_index = {} + for file in files: + image_id = re.findall(r'\d+', file) + if len(image_id) > 0: + image_id = int(image_id[-1]) + image_index[image_id] = mmengine.join_path(img_prefix, file) + + return image_index + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + # The original Visual Genome annotation file and question file includes + # only image id but no image file paths. 
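+ # ``_create_image_index`` maps the numeric id parsed from each file name + # (e.g. ``image/1.jpg`` -> 1) to its full path, so the ``image_id`` field + # of every annotation can be resolved to an image file below.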
+ self.image_index = self._create_image_index() + + data_list = [] + for qas in chain.from_iterable(ann['qas'] for ann in annotations): + # ann example + # { + # 'id': 1, + # 'qas': [ + # { + # 'a_objects': [], + # 'question': 'What color is the clock?', + # 'image_id': 1, + # 'qa_id': 986768, + # 'answer': 'Two.', + # 'q_objects': [], + # } + # ... + # ] + # } + + data_info = { + 'img_path': self.image_index[qas['image_id']], + 'question': qas['question'], + 'question_id': qas['qa_id'], + 'image_id': qas['image_id'], + 'gt_answer': [qas['answer']], + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/vizwiz.py b/mmpretrain/datasets/vizwiz.py new file mode 100644 index 0000000..7b5dd39 --- /dev/null +++ b/mmpretrain/datasets/vizwiz.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import Counter +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class VizWiz(BaseDataset): + """VizWiz dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + data_list = [] + for ann in annotations: + # { + # "image": "VizWiz_val_00000001.jpg", + # "question": "Can you tell me what this medicine is please?", + # "answers": [ + # { + # "answer": "no", + # "answer_confidence": "yes" + # }, + # { + # "answer": "unanswerable", + # "answer_confidence": "yes" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "unanswerable", + # "answer_confidence": "yes" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time cold medicine", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time", + # "answer_confidence": "maybe" + # }, + # { + # "answer": "night time medicine", + # "answer_confidence": "yes" + # } + # ], + # "answer_type": "other", + # "answerable": 1 + # }, + data_info = dict() + data_info['question'] = ann['question'] + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], ann['image']) + + if 'answerable' not in ann: + data_list.append(data_info) + else: + if ann['answerable'] == 1: + # add answer_weight & answer_count, delete duplicate answer + answers = [] + for item in ann.pop('answers'): + if item['answer_confidence'] == 'yes' and item[ + 'answer'] != 'unanswerable': + answers.append(item['answer']) + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + data_info['gt_answer'] = list(count.keys()) + data_info['gt_answer_weight'] = answer_weight + # data_info.update(ann) + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/datasets/voc.py b/mmpretrain/datasets/voc.py new
file mode 100644 index 0000000..39544de --- /dev/null +++ b/mmpretrain/datasets/voc.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import xml.etree.ElementTree as ET +from typing import List, Optional, Union + +from mmengine import get_file_backend, list_from_file +from mmengine.logging import MMLogger + +from mmpretrain.registry import DATASETS +from .base_dataset import expanduser +from .categories import VOC2007_CATEGORIES +from .multi_label import MultiLabelDataset + + +@DATASETS.register_module() +class VOC(MultiLabelDataset): + """`Pascal VOC `_ Dataset. + + After decompression, the dataset directory structure is as follows: + + VOC dataset directory: :: + + VOC2007 + ├── JPEGImages + │ ├── xxx.jpg + │ ├── xxy.jpg + │ └── ... + ├── Annotations + │ ├── xxx.xml + │ ├── xxy.xml + │ └── ... + └── ImageSets + └── Main + ├── train.txt + ├── val.txt + ├── trainval.txt + ├── test.txt + └── ... + + Extra difficult label is in VOC annotations, we will use + `gt_label_difficult` to record the difficult labels in each sample + and corresponding evaluation should take care of this field + to calculate metrics. Usually, difficult labels are reckoned as + negative in defaults. + + Args: + data_root (str): The root directory for VOC dataset. + split (str, optional): The dataset split, supports "train", + "val", "trainval", and "test". Default to "trainval". + image_set_path (str, optional): The path of image set, The file which + lists image ids of the sub dataset, and this path is relative + to ``data_root``. Default to ''. + data_prefix (dict): Prefix for data and annotation, keyword + 'img_path' and 'ann_path' can be set. Defaults to be + ``dict(img_path='JPEGImages', ann_path='Annotations')``. + metainfo (dict, optional): Meta information for dataset, such as + categories information. Defaults to None. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
+ + Examples: + >>> from mmpretrain.datasets import VOC + >>> train_dataset = VOC(data_root='data/VOC2007', split='trainval') + >>> train_dataset + Dataset VOC + Number of samples: 5011 + Number of categories: 20 + Prefix of dataset: data/VOC2007 + Path of image set: data/VOC2007/ImageSets/Main/trainval.txt + Prefix of images: data/VOC2007/JPEGImages + Prefix of annotations: data/VOC2007/Annotations + >>> test_dataset = VOC(data_root='data/VOC2007', split='test') + >>> test_dataset + Dataset VOC + Number of samples: 4952 + Number of categories: 20 + Prefix of dataset: data/VOC2007 + Path of image set: data/VOC2007/ImageSets/Main/test.txt + Prefix of images: data/VOC2007/JPEGImages + Prefix of annotations: data/VOC2007/Annotations + """ # noqa: E501 + + METAINFO = {'classes': VOC2007_CATEGORIES} + + def __init__(self, + data_root: str, + split: str = 'trainval', + image_set_path: str = '', + data_prefix: Union[str, dict] = dict( + img_path='JPEGImages', ann_path='Annotations'), + test_mode: bool = False, + metainfo: Optional[dict] = None, + **kwargs): + + self.backend = get_file_backend(data_root, enable_singleton=True) + + if split: + splits = ['train', 'val', 'trainval', 'test'] + assert split in splits, \ + f"The split must be one of {splits}, but get '{split}'" + self.split = split + + if not data_prefix: + data_prefix = dict( + img_path='JPEGImages', ann_path='Annotations') + if not image_set_path: + image_set_path = self.backend.join_path( + 'ImageSets', 'Main', f'{split}.txt') + + # To handle the BC-breaking + if (split == 'train' or split == 'trainval') and test_mode: + logger = MMLogger.get_current_instance() + logger.warning(f'split="{split}" but test_mode=True. ' + f'The {split} set will be used.') + + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + assert isinstance(data_prefix, dict) and 'img_path' in data_prefix, \ + '`data_prefix` must be a dict with key img_path' + + if (split and split not in ['val', 'test']) or not test_mode: + assert 'ann_path' in data_prefix and data_prefix[ + 'ann_path'] is not None, \ + '"ann_path" must be set in `data_prefix`' \ + 'when validation or test set is used.' + + self.data_root = data_root + self.image_set_path = self.backend.join_path(data_root, image_set_path) + + super().__init__( + ann_file='', + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + @property + def ann_prefix(self): + """The prefix of images.""" + if 'ann_path' in self.data_prefix: + return self.data_prefix['ann_path'] + else: + return None + + def _get_labels_from_xml(self, img_id): + """Get gt_labels and labels_difficult from xml file.""" + xml_path = self.backend.join_path(self.ann_prefix, f'{img_id}.xml') + content = self.backend.get(xml_path) + root = ET.fromstring(content) + + labels, labels_difficult = set(), set() + for obj in root.findall('object'): + label_name = obj.find('name').text + # in case customized dataset has wrong labels + # or CLASSES has been override. 
+ if label_name not in self.CLASSES: + continue + label = self.class_to_idx[label_name] + difficult = int(obj.find('difficult').text) + if difficult: + labels_difficult.add(label) + else: + labels.add(label) + + return list(labels), list(labels_difficult) + + def load_data_list(self): + """Load images and ground truth labels.""" + data_list = [] + img_ids = list_from_file(self.image_set_path) + + for img_id in img_ids: + img_path = self.backend.join_path(self.img_prefix, f'{img_id}.jpg') + + labels, labels_difficult = None, None + if self.ann_prefix is not None: + labels, labels_difficult = self._get_labels_from_xml(img_id) + + info = dict( + img_path=img_path, + gt_label=labels, + gt_label_difficult=labels_difficult) + data_list.append(info) + + return data_list + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [ + f'Prefix of dataset: \t{self.data_root}', + f'Path of image set: \t{self.image_set_path}', + f'Prefix of images: \t{self.img_prefix}', + f'Prefix of annotations: \t{self.ann_prefix}' + ] + + return body diff --git a/mmpretrain/datasets/vsr.py b/mmpretrain/datasets/vsr.py new file mode 100644 index 0000000..7b10959 --- /dev/null +++ b/mmpretrain/datasets/vsr.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class VSR(BaseDataset): + """VSR: Visual Spatial Reasoning dataset. + + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. + data_prefix (str): The directory of images. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def load_data_list(self) -> List[dict]: + """Load data list.""" + annotations = mmengine.load(self.ann_file) + + data_list = [] + for ann in annotations: + # ann example + # { + # "image": "train2017/000000372029.jpg", + # "question": "The dog is on the surfboard.", + # "answer": true + # } + data_info = dict() + data_info['img_path'] = mmengine.join_path( + self.data_prefix['img_path'], ann['image']) + data_info['question'] = ann['question'] + data_info['gt_answer'] = 'yes' if ann['answer'] else 'no' + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/engine/__init__.py b/mmpretrain/engine/__init__.py new file mode 100644 index 0000000..332fea0 --- /dev/null +++ b/mmpretrain/engine/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hooks import * # noqa: F401, F403 +from .optimizers import * # noqa: F401, F403 +from .runners import * # noqa: F401, F403 +from .schedulers import * # noqa: F401, F403 diff --git a/mmpretrain/engine/hooks/__init__.py b/mmpretrain/engine/hooks/__init__.py new file mode 100644 index 0000000..bc9e22b --- /dev/null +++ b/mmpretrain/engine/hooks/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .class_num_check_hook import ClassNumCheckHook +from .densecl_hook import DenseCLHook +from .ema_hook import EMAHook +from .margin_head_hooks import SetAdaptiveMarginsHook +from .precise_bn_hook import PreciseBNHook +from .retriever_hooks import PrepareProtoBeforeValLoopHook +from .simsiam_hook import SimSiamHook +from .swav_hook import SwAVHook +from .switch_recipe_hook import SwitchRecipeHook +from .visualization_hook import VisualizationHook +from .warmup_param_hook import WarmupParamHook + +__all__ = [ + 'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook', + 'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook', + 'SetAdaptiveMarginsHook', 'EMAHook', 'SimSiamHook', 'DenseCLHook', + 'SwAVHook', 'WarmupParamHook' +] diff --git a/mmpretrain/engine/hooks/class_num_check_hook.py b/mmpretrain/engine/hooks/class_num_check_hook.py new file mode 100644 index 0000000..38170d6 --- /dev/null +++ b/mmpretrain/engine/hooks/class_num_check_hook.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved +from mmengine.hooks import Hook +from mmengine.utils import is_seq_of + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class ClassNumCheckHook(Hook): + """Class Number Check HOOK.""" + + def _check_head(self, runner, dataset): + """Check whether the `num_classes` in head matches the length of + `CLASSES` in `dataset`. + + Args: + runner (obj:`Runner`): runner object. + dataset (obj: `BaseDataset`): the dataset to check. + """ + model = runner.model + if dataset.CLASSES is None: + runner.logger.warning( + f'Please set class information in `metainfo` ' + f'in the {dataset.__class__.__name__} and' + f'check if it is consistent with the `num_classes` ' + f'of head') + else: + assert is_seq_of(dataset.CLASSES, str), \ + (f'Class information in `metainfo` in ' + f'{dataset.__class__.__name__} should be a tuple of str.') + for _, module in model.named_modules(): + if hasattr(module, 'num_classes'): + assert module.num_classes == len(dataset.CLASSES), \ + (f'The `num_classes` ({module.num_classes}) in ' + f'{module.__class__.__name__} of ' + f'{model.__class__.__name__} does not matches ' + f'the length of class information in `metainfo` ' + f'{len(dataset.CLASSES)}) in ' + f'{dataset.__class__.__name__}') + + def before_train(self, runner): + """Check whether the training dataset is compatible with head. + + Args: + runner (obj: `IterBasedRunner`): Iter based Runner. + """ + self._check_head(runner, runner.train_dataloader.dataset) + + def before_val(self, runner): + """Check whether the validation dataset is compatible with head. + + Args: + runner (obj:`IterBasedRunner`): Iter based Runner. + """ + self._check_head(runner, runner.val_dataloader.dataset) + + def before_test(self, runner): + """Check whether the test dataset is compatible with head. + + Args: + runner (obj:`IterBasedRunner`): Iter based Runner. + """ + self._check_head(runner, runner.test_dataloader.dataset) diff --git a/mmpretrain/engine/hooks/densecl_hook.py b/mmpretrain/engine/hooks/densecl_hook.py new file mode 100644 index 0000000..8c7e17d --- /dev/null +++ b/mmpretrain/engine/hooks/densecl_hook.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class DenseCLHook(Hook): + """Hook for DenseCL. + + This hook includes ``loss_lambda`` warmup in DenseCL. + Borrowed from the authors' code: ``_. 
+ + Args: + start_iters (int): The number of warmup iterations to set + ``loss_lambda=0``. Defaults to 1000. + """ + + def __init__(self, start_iters: int = 1000) -> None: + self.start_iters = start_iters + + def before_train(self, runner) -> None: + """Obtain ``loss_lambda`` from algorithm.""" + assert hasattr(get_ori_model(runner.model), 'loss_lambda'), \ + "The runner must have attribute \"loss_lambda\" in DenseCL." + self.loss_lambda = get_ori_model(runner.model).loss_lambda + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """Adjust ``loss_lambda`` every train iter.""" + assert hasattr(get_ori_model(runner.model), 'loss_lambda'), \ + "The runner must have attribute \"loss_lambda\" in DenseCL." + cur_iter = runner.iter + if cur_iter >= self.start_iters: + get_ori_model(runner.model).loss_lambda = self.loss_lambda + else: + get_ori_model(runner.model).loss_lambda = 0. diff --git a/mmpretrain/engine/hooks/ema_hook.py b/mmpretrain/engine/hooks/ema_hook.py new file mode 100644 index 0000000..284d211 --- /dev/null +++ b/mmpretrain/engine/hooks/ema_hook.py @@ -0,0 +1,216 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import itertools +import warnings +from typing import Dict, Optional + +from mmengine.hooks import EMAHook as BaseEMAHook +from mmengine.logging import MMLogger +from mmengine.runner import Runner + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class EMAHook(BaseEMAHook): + """A Hook to apply Exponential Moving Average (EMA) on the model during + training. + + Comparing with :class:`mmengine.hooks.EMAHook`, this hook accepts + ``evaluate_on_ema`` and ``evaluate_on_origin`` arguments. By default, the + ``evaluate_on_ema`` is enabled, and if you want to do validation and + testing on both original and EMA models, please set both arguments + ``True``. + + Note: + - EMAHook takes priority over CheckpointHook. + - The original model parameters are actually saved in ema field after + train. + - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time. + + Args: + ema_type (str): The type of EMA strategy to use. You can find the + supported strategies in :mod:`mmengine.model.averaged_model`. + Defaults to 'ExponentialMovingAverage'. + strict_load (bool): Whether to strictly enforce that the keys of + ``state_dict`` in checkpoint match the keys returned by + ``self.module.state_dict``. Defaults to False. + Changed in v0.3.0. + begin_iter (int): The number of iteration to enable ``EMAHook``. + Defaults to 0. + begin_epoch (int): The number of epoch to enable ``EMAHook``. + Defaults to 0. + evaluate_on_ema (bool): Whether to evaluate (validate and test) + on EMA model during val-loop and test-loop. Defaults to True. + evaluate_on_origin (bool): Whether to evaluate (validate and test) + on the original model during val-loop and test-loop. + Defaults to False. 
+ **kwargs: Keyword arguments passed to subclasses of + :obj:`BaseAveragedModel` + """ + + priority = 'NORMAL' + + def __init__(self, + ema_type: str = 'ExponentialMovingAverage', + strict_load: bool = False, + begin_iter: int = 0, + begin_epoch: int = 0, + evaluate_on_ema: bool = True, + evaluate_on_origin: bool = False, + **kwargs): + super().__init__( + ema_type=ema_type, + strict_load=strict_load, + begin_iter=begin_iter, + begin_epoch=begin_epoch, + **kwargs) + + if not evaluate_on_ema and not evaluate_on_origin: + warnings.warn( + 'Automatically set `evaluate_on_origin=True` since the ' + '`evaluate_on_ema` is disabled. If you want to disable ' + 'all validation, please modify the `val_interval` of ' + 'the `train_cfg`.', UserWarning) + evaluate_on_origin = True + + self.evaluate_on_ema = evaluate_on_ema + self.evaluate_on_origin = evaluate_on_origin + self.load_ema_from_ckpt = False + + def before_train(self, runner) -> None: + super().before_train(runner) + if not runner._resume and self.load_ema_from_ckpt: + # If loaded EMA state dict but not want to resume training + # overwrite the EMA state dict with the source model. + MMLogger.get_current_instance().info( + 'Load from a checkpoint with EMA parameters but not ' + 'resume training. Initialize the model parameters with ' + 'EMA parameters') + for p_ema, p_src in zip(self._ema_params, self._src_params): + p_src.data.copy_(p_ema.data) + + def before_val_epoch(self, runner) -> None: + """We load parameter values from ema model to source model before + validation. + + Args: + runner (Runner): The runner of the training process. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + def after_val_epoch(self, + runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """We recover source model's parameter from ema model after validation. + + Args: + runner (Runner): The runner of the validation process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + if self.evaluate_on_ema and self.evaluate_on_origin: + # Re-evaluate if evaluate on both ema and origin. + val_loop = runner.val_loop + + runner.model.eval() + for idx, data_batch in enumerate(val_loop.dataloader): + val_loop.run_iter(idx, data_batch) + + # compute metrics + origin_metrics = val_loop.evaluator.evaluate( + len(val_loop.dataloader.dataset)) + + for k, v in origin_metrics.items(): + runner.message_hub.update_scalar(f'val/{k}_origin', v) + + def before_test_epoch(self, runner) -> None: + """We load parameter values from ema model to source model before test. + + Args: + runner (Runner): The runner of the training process. + """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + MMLogger.get_current_instance().info('Start testing on EMA model.') + else: + MMLogger.get_current_instance().info( + 'Start testing on the original model.') + + def after_test_epoch(self, + runner: Runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """We recover source model's parameter from ema model after test. + + Args: + runner (Runner): The runner of the testing process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on test dataset. The keys are the names of the + metrics, and the values are corresponding results. 
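+ + Note: when both ``evaluate_on_ema`` and ``evaluate_on_origin`` are + enabled, the metrics of the original model are additionally logged + with an ``_origin`` suffix (e.g. ``val/accuracy_origin``).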
+ """ + if self.evaluate_on_ema: + # Swap when evaluate on ema + self._swap_ema_parameters() + + if self.evaluate_on_ema and self.evaluate_on_origin: + # Re-evaluate if evaluate on both ema and origin. + MMLogger.get_current_instance().info( + 'Start testing on the original model.') + test_loop = runner.test_loop + + runner.model.eval() + for idx, data_batch in enumerate(test_loop.dataloader): + test_loop.run_iter(idx, data_batch) + + # compute metrics + origin_metrics = test_loop.evaluator.evaluate( + len(test_loop.dataloader.dataset)) + + for k, v in origin_metrics.items(): + runner.message_hub.update_scalar(f'test/{k}_origin', v) + + def after_load_checkpoint(self, runner, checkpoint: dict) -> None: + """Resume ema parameters from checkpoint. + + Args: + runner (Runner): The runner of the testing process. + """ + from mmengine.runner.checkpoint import load_state_dict + if 'ema_state_dict' in checkpoint: + # The original model parameters are actually saved in ema + # field swap the weights back to resume ema state. + self._swap_ema_state_dict(checkpoint) + self.ema_model.load_state_dict( + checkpoint['ema_state_dict'], strict=self.strict_load) + self.load_ema_from_ckpt = True + + # Support load checkpoint without ema state dict. + else: + load_state_dict( + self.ema_model.module, + copy.deepcopy(checkpoint['state_dict']), + strict=self.strict_load) + + @property + def _src_params(self): + if self.ema_model.update_buffers: + return itertools.chain(self.src_model.parameters(), + self.src_model.buffers()) + else: + return self.src_model.parameters() + + @property + def _ema_params(self): + if self.ema_model.update_buffers: + return itertools.chain(self.ema_model.module.parameters(), + self.ema_model.module.buffers()) + else: + return self.ema_model.module.parameters() diff --git a/mmpretrain/engine/hooks/margin_head_hooks.py b/mmpretrain/engine/hooks/margin_head_hooks.py new file mode 100644 index 0000000..fbeae7a --- /dev/null +++ b/mmpretrain/engine/hooks/margin_head_hooks.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved +import numpy as np +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models.heads import ArcFaceClsHead +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class SetAdaptiveMarginsHook(Hook): + r"""Set adaptive-margins in ArcFaceClsHead based on the power of + category-wise count. + + A PyTorch implementation of paper `Google Landmark Recognition 2020 + Competition Third Place Solution `_. + The margins will be + :math:`\text{f}(n) = (marginMax - marginMin) · norm(n^p) + marginMin`. + The `n` indicates the number of occurrences of a category. + + Args: + margin_min (float): Lower bound of margins. Defaults to 0.05. + margin_max (float): Upper bound of margins. Defaults to 0.5. + power (float): The power of category freqercy. Defaults to -0.25. + """ + + def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None: + self.margin_min = margin_min + self.margin_max = margin_max + self.margin_range = margin_max - margin_min + self.p = power + + def before_train(self, runner): + """change the margins in ArcFaceClsHead. + + Args: + runner (obj: `Runner`): Runner. 
+ """ + model = runner.model + if is_model_wrapper(model): + model = model.module + + if (hasattr(model, 'head') + and not isinstance(model.head, ArcFaceClsHead)): + raise ValueError( + 'Hook ``SetFreqPowAdvMarginsHook`` could only be used ' + f'for ``ArcFaceClsHead``, but get {type(model.head)}') + + # generate margins base on the dataset. + gt_labels = runner.train_dataloader.dataset.get_gt_labels() + label_count = np.bincount(gt_labels) + label_count[label_count == 0] = 1 # At least one occurrence + pow_freq = np.power(label_count, self.p) + + min_f, max_f = pow_freq.min(), pow_freq.max() + normized_pow_freq = (pow_freq - min_f) / (max_f - min_f) + margins = normized_pow_freq * self.margin_range + self.margin_min + + assert len(margins) == runner.model.head.num_classes + + model.head.set_margins(margins) diff --git a/mmpretrain/engine/hooks/precise_bn_hook.py b/mmpretrain/engine/hooks/precise_bn_hook.py new file mode 100644 index 0000000..4fb0e4c --- /dev/null +++ b/mmpretrain/engine/hooks/precise_bn_hook.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py # noqa: E501 +# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0 # noqa: E501 + +import itertools +import logging +from typing import List, Optional, Sequence, Union + +import mmengine +import torch +import torch.nn as nn +from mmengine.hooks import Hook +from mmengine.logging import print_log +from mmengine.model import is_model_wrapper +from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop, Runner +from mmengine.utils import ProgressBar +from torch.functional import Tensor +from torch.nn import GroupNorm +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.instancenorm import _InstanceNorm +from torch.utils.data import DataLoader + +from mmpretrain.registry import HOOKS + +DATA_BATCH = Optional[Sequence[dict]] + + +def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]: + """Performs the scaled all_reduce operation on the provided tensors. + + The input tensors are modified in-place. Currently supports only the sum + reduction operator. The reduced values are scaled by the inverse size of + the process group. + + Args: + tensors (List[torch.Tensor]): The tensors to process. + num_gpus (int): The number of gpus to use + Returns: + List[torch.Tensor]: The processed tensors. + """ + # There is no need for reduction in the single-proc case + if num_gpus == 1: + return tensors + # Queue the reductions + reductions = [] + for tensor in tensors: + reduction = torch.distributed.all_reduce(tensor, async_op=True) + reductions.append(reduction) + # Wait for reductions to finish + for reduction in reductions: + reduction.wait() + # Scale the results + for tensor in tensors: + tensor.mul_(1.0 / num_gpus) + return tensors + + +@torch.no_grad() +def update_bn_stats( + model: nn.Module, + loader: DataLoader, + num_samples: int = 8192, + logger: Optional[Union[logging.Logger, str]] = None) -> None: + """Computes precise BN stats on training data. + + Args: + model (nn.module): The model whose bn stats will be recomputed. + loader (DataLoader): PyTorch dataloader._dataloader + num_samples (int): The number of samples to update the bn stats. + Defaults to 8192. + logger (logging.Logger or str, optional): If the type of logger is + ``logging.Logger``, we directly use logger to log messages. 
+ Some special loggers are: + - "silent": No message will be printed. + - "current": Use latest created logger to log message. + - other str: Instance name of logger. The corresponding logger + will log message if it has been created, otherwise will raise a + `ValueError`. + - None: The `print()` method will be used to print log messages. + """ + if is_model_wrapper(model): + model = model.module + + # get dist info + rank, world_size = mmengine.dist.get_dist_info() + # Compute the number of mini-batches to use, if the size of dataloader is + # less than num_iters, use all the samples in dataloader. + num_iter = num_samples // (loader.batch_size * world_size) + num_iter = min(num_iter, len(loader)) + # Retrieve the BN layers + bn_layers = [ + m for m in model.modules() + if m.training and isinstance(m, (_BatchNorm)) + ] + if len(bn_layers) == 0: + print_log('No BN found in model', logger=logger, level=logging.WARNING) + return + print_log( + f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger) + + # Finds all the other norm layers with training=True. + other_norm_layers = [ + m for m in model.modules() + if m.training and isinstance(m, (_InstanceNorm, GroupNorm)) + ] + if len(other_norm_layers) > 0: + print_log( + 'IN/GN stats will not be updated in PreciseHook.', + logger=logger, + level=logging.INFO) + + # Initialize BN stats storage for computing + # mean(mean(batch)) and mean(var(batch)) + running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers] + running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers] + # Remember momentum values + momentums = [bn.momentum for bn in bn_layers] + # Set momentum to 1.0 to compute BN stats that reflect the current batch + for bn in bn_layers: + bn.momentum = 1.0 + # Average the BN stats for each BN layer over the batches + if rank == 0: + prog_bar = ProgressBar(num_iter) + + for data in itertools.islice(loader, num_iter): + data = model.data_preprocessor(data, False) + model(**data) + + for i, bn in enumerate(bn_layers): + running_means[i] += bn.running_mean / num_iter + running_vars[i] += bn.running_var / num_iter + if rank == 0: + prog_bar.update() + + # Sync BN stats across GPUs (no reduction if 1 GPU used) + running_means = scaled_all_reduce(running_means, world_size) + running_vars = scaled_all_reduce(running_vars, world_size) + # Set BN stats and restore original momentum values + for i, bn in enumerate(bn_layers): + bn.running_mean = running_means[i] + bn.running_var = running_vars[i] + bn.momentum = momentums[i] + + +@HOOKS.register_module() +class PreciseBNHook(Hook): + """Precise BN hook. + + Recompute and update the batch norm stats to make them more precise. During + training both BN stats and the weight are changing after every iteration, + so the running average can not precisely reflect the actual stats of the + current model. + + With this hook, the BN stats are recomputed with fixed weights, to make the + running average more precise. Specifically, it computes the true average of + per-batch mean/variance instead of the running average. See Sec. 3 of the + paper `Rethinking Batch in BatchNorm ` + for details. + + This hook will update BN stats, so it should be executed before + ``CheckpointHook`` and ``EMAHook``, generally set its priority to + "ABOVE_NORMAL". + + Args: + num_samples (int): The number of samples to update the bn stats. + Defaults to 8192. + interval (int): Perform precise bn interval. 
If the train loop is + `EpochBasedTrainLoop` or `by_epoch=True`, its unit is 'epoch'; if the + train loop is `IterBasedTrainLoop` or `by_epoch=False`, its unit is + 'iter'. Defaults to 1. + """ + + def __init__(self, num_samples: int = 8192, interval: int = 1) -> None: + assert interval > 0 and num_samples > 0, "'interval' and " \ + "'num_samples' must be bigger than 0." + + self.interval = interval + self.num_samples = num_samples + + def _perform_precise_bn(self, runner: Runner) -> None: + """perform precise bn.""" + print_log( + f'Running Precise BN for {self.num_samples} samples...', + logger=runner.logger) + update_bn_stats( + runner.model, + runner.train_loop.dataloader, + self.num_samples, + logger=runner.logger) + print_log('Finish Precise BN, BN stats updated.', logger=runner.logger) + + def after_train_epoch(self, runner: Runner) -> None: + """Calculate prcise BN and broadcast BN stats across GPUs. + + Args: + runner (obj:`Runner`): The runner of the training process. + """ + # if use `EpochBasedTrainLoop``, do perform precise every + # `self.interval` epochs. + if isinstance(runner.train_loop, + EpochBasedTrainLoop) and self.every_n_epochs( + runner, self.interval): + self._perform_precise_bn(runner) + + def after_train_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: + """Calculate prcise BN and broadcast BN stats across GPUs. + + Args: + runner (obj:`Runner`): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. + """ + # if use `IterBasedTrainLoop``, do perform precise every + # `self.interval` iters. + if isinstance(runner.train_loop, + IterBasedTrainLoop) and self.every_n_train_iters( + runner, self.interval): + self._perform_precise_bn(runner) diff --git a/mmpretrain/engine/hooks/retriever_hooks.py b/mmpretrain/engine/hooks/retriever_hooks.py new file mode 100644 index 0000000..6bd7c7a --- /dev/null +++ b/mmpretrain/engine/hooks/retriever_hooks.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved +import warnings + +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models import BaseRetriever +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class PrepareProtoBeforeValLoopHook(Hook): + """The hook to prepare the prototype in retrievers. + + Since the encoders of the retriever changes during training, the prototype + changes accordingly. So the `prototype_vecs` needs to be regenerated before + validation loop. + """ + + def before_val(self, runner) -> None: + model = runner.model + if is_model_wrapper(model): + model = model.module + + if isinstance(model, BaseRetriever): + if hasattr(model, 'prepare_prototype'): + model.prepare_prototype() + else: + warnings.warn( + 'Only the `mmpretrain.models.retrievers.BaseRetriever` ' + 'can execute `PrepareRetrieverPrototypeHook`, but got ' + f'`{type(model)}`') diff --git a/mmpretrain/engine/hooks/simsiam_hook.py b/mmpretrain/engine/hooks/simsiam_hook.py new file mode 100644 index 0000000..fabc4fa --- /dev/null +++ b/mmpretrain/engine/hooks/simsiam_hook.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class SimSiamHook(Hook): + """Hook for SimSiam. 
+ + This hook is for SimSiam to fix learning rate of predictor. + + Args: + fix_pred_lr (bool): whether to fix the lr of predictor or not. + lr (float): the value of fixed lr. + adjust_by_epoch (bool, optional): whether to set lr by epoch or iter. + Defaults to True. + """ + + def __init__(self, + fix_pred_lr: bool, + lr: float, + adjust_by_epoch: Optional[bool] = True) -> None: + self.fix_pred_lr = fix_pred_lr + self.lr = lr + self.adjust_by_epoch = adjust_by_epoch + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """fix lr of predictor by iter.""" + if self.adjust_by_epoch: + return + else: + if self.fix_pred_lr: + for param_group in runner.optim_wrapper.optimizer.param_groups: + if 'fix_lr' in param_group and param_group['fix_lr']: + param_group['lr'] = self.lr + + def before_train_epoch(self, runner) -> None: + """fix lr of predictor by epoch.""" + if self.fix_pred_lr: + for param_group in runner.optim_wrapper.optimizer.param_groups: + if 'fix_lr' in param_group and param_group['fix_lr']: + param_group['lr'] = self.lr diff --git a/mmpretrain/engine/hooks/swav_hook.py b/mmpretrain/engine/hooks/swav_hook.py new file mode 100644 index 0000000..be5f3a3 --- /dev/null +++ b/mmpretrain/engine/hooks/swav_hook.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Dict, List, Optional, Sequence + +import torch +from mmengine.device import get_device +from mmengine.dist import get_rank, get_world_size, is_distributed +from mmengine.hooks import Hook +from mmengine.logging import MMLogger + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class SwAVHook(Hook): + """Hook for SwAV. + + This hook builds the queue in SwAV according to ``epoch_queue_starts``. + The queue will be saved in ``runner.work_dir`` or loaded at start epoch + if the path folder has queues saved before. + + Args: + batch_size (int): the batch size per GPU for computing. + epoch_queue_starts (int, optional): from this epoch, starts to use the + queue. Defaults to 15. + crops_for_assign (list[int], optional): list of crops id used for + computing assignments. Defaults to [0, 1]. + feat_dim (int, optional): feature dimension of output vector. + Defaults to 128. + queue_length (int, optional): length of the queue (0 for no queue). + Defaults to 0. + interval (int, optional): the interval to save the queue. + Defaults to 1. + frozen_layers_cfg (dict, optional): Dict to config frozen layers. + The key-value pair is layer name and its frozen iters. If frozen, + the layers don't need gradient. Defaults to dict(). 
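+
+    Example:
+        A minimal, illustrative config snippet; the concrete values below are
+        assumptions for demonstration rather than a reference SwAV recipe::
+
+            custom_hooks = [
+                dict(
+                    type='SwAVHook',
+                    batch_size=32,
+                    epoch_queue_starts=15,
+                    crops_for_assign=[0, 1],
+                    feat_dim=128,
+                    queue_length=3840)
+            ]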
+ """ + + def __init__( + self, + batch_size: int, + epoch_queue_starts: Optional[int] = 15, + crops_for_assign: Optional[List[int]] = [0, 1], + feat_dim: Optional[int] = 128, + queue_length: Optional[int] = 0, + interval: Optional[int] = 1, + frozen_layers_cfg: Optional[Dict] = dict() + ) -> None: + self.batch_size = batch_size * get_world_size() + self.epoch_queue_starts = epoch_queue_starts + self.crops_for_assign = crops_for_assign + self.feat_dim = feat_dim + self.queue_length = queue_length + self.interval = interval + self.frozen_layers_cfg = frozen_layers_cfg + self.requires_grad = True + self.queue = None + + def before_run(self, runner) -> None: + """Check whether the queues exist locally or not.""" + if is_distributed(): + self.queue_path = osp.join(runner.work_dir, + 'queue' + str(get_rank()) + '.pth') + else: + self.queue_path = osp.join(runner.work_dir, 'queue.pth') + + # load the queues if queues exist locally + if osp.isfile(self.queue_path): + self.queue = torch.load(self.queue_path)['queue'] + get_ori_model(runner.model).head.loss_module.queue = self.queue + MMLogger.get_current_instance().info( + f'Load queue from file: {self.queue_path}') + + # the queue needs to be divisible by the batch size + self.queue_length -= self.queue_length % self.batch_size + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: Optional[Sequence[dict]] = None) -> None: + """Freeze layers before specific iters according to the config.""" + for layer, frozen_iters in self.frozen_layers_cfg.items(): + if runner.iter < frozen_iters and self.requires_grad: + self.requires_grad = False + for name, p in get_ori_model(runner.model).named_parameters(): + if layer in name: + p.requires_grad = False + elif runner.iter >= frozen_iters and not self.requires_grad: + self.requires_grad = True + for name, p in get_ori_model(runner.model).named_parameters(): + if layer in name: + p.requires_grad = True + + def before_train_epoch(self, runner) -> None: + """Check the queues' state.""" + # optionally starts a queue + if self.queue_length > 0 \ + and runner.epoch >= self.epoch_queue_starts \ + and self.queue is None: + + self.queue = torch.zeros( + len(self.crops_for_assign), + self.queue_length // runner.world_size, + self.feat_dim, + device=get_device(), + ) + + # set the boolean type of use_the_queue + get_ori_model(runner.model).head.loss_module.queue = self.queue + get_ori_model(runner.model).head.loss_module.use_queue = False + + def after_train_epoch(self, runner) -> None: + """Save the queues locally.""" + self.queue = get_ori_model(runner.model).head.loss_module.queue + + if self.queue is not None and self.every_n_epochs( + runner, self.interval): + torch.save({'queue': self.queue}, self.queue_path) diff --git a/mmpretrain/engine/hooks/switch_recipe_hook.py b/mmpretrain/engine/hooks/switch_recipe_hook.py new file mode 100644 index 0000000..914b957 --- /dev/null +++ b/mmpretrain/engine/hooks/switch_recipe_hook.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from copy import deepcopy + +from mmcv.transforms import Compose +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmpretrain.models.utils import RandomBatchAugment +from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS + + +@HOOKS.register_module() +class SwitchRecipeHook(Hook): + """switch recipe during the training loop, including train pipeline, batch + augments and loss currently. 
+ + Args: + schedule (list): Every item of the schedule list should be a dict, and + the dict should have ``action_epoch`` and some of + ``train_pipeline``, ``train_augments`` and ``loss`` keys: + + - ``action_epoch`` (int): switch training recipe at which epoch. + - ``train_pipeline`` (list, optional): The new data pipeline of the + train dataset. If not specified, keep the original settings. + - ``batch_augments`` (dict | None, optional): The new batch + augmentations of during training. See :mod:`Batch Augmentations + ` for more details. + If None, disable batch augmentations. If not specified, keep the + original settings. + - ``loss`` (dict, optional): The new loss module config. If not + specified, keep the original settings. + + Example: + To use this hook in config files. + + .. code:: python + + custom_hooks = [ + dict( + type='SwitchRecipeHook', + schedule=[ + dict( + action_epoch=30, + train_pipeline=pipeline_after_30e, + batch_augments=batch_augments_after_30e, + loss=loss_after_30e, + ), + dict( + action_epoch=60, + # Disable batch augmentations after 60e + # and keep other settings. + batch_augments=None, + ), + ] + ) + ] + """ + priority = 'NORMAL' + + def __init__(self, schedule): + recipes = {} + for recipe in schedule: + assert 'action_epoch' in recipe, \ + 'Please set `action_epoch` in every item ' \ + 'of the `schedule` in the SwitchRecipeHook.' + recipe = deepcopy(recipe) + if 'train_pipeline' in recipe: + recipe['train_pipeline'] = Compose(recipe['train_pipeline']) + if 'batch_augments' in recipe: + batch_augments = recipe['batch_augments'] + if isinstance(batch_augments, dict): + batch_augments = RandomBatchAugment(**batch_augments) + recipe['batch_augments'] = batch_augments + if 'loss' in recipe: + loss = recipe['loss'] + if isinstance(loss, dict): + loss = MODELS.build(loss) + recipe['loss'] = loss + + action_epoch = recipe.pop('action_epoch') + assert action_epoch not in recipes, \ + f'The `action_epoch` {action_epoch} is repeated ' \ + 'in the SwitchRecipeHook.' + recipes[action_epoch] = recipe + self.schedule = OrderedDict(sorted(recipes.items())) + + def before_train(self, runner) -> None: + """before run setting. If resume form a checkpoint, do all switch + before the current epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process. 
+ """ + if runner._resume: + for action_epoch, recipe in self.schedule.items(): + if action_epoch >= runner.epoch + 1: + break + self._do_switch(runner, recipe, + f' (resume recipe of epoch {action_epoch})') + + def before_train_epoch(self, runner): + """do before train epoch.""" + recipe = self.schedule.get(runner.epoch + 1, None) + if recipe is not None: + self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}') + + def _do_switch(self, runner, recipe, extra_info=''): + """do the switch aug process.""" + if 'batch_augments' in recipe: + self._switch_batch_augments(runner, recipe['batch_augments']) + runner.logger.info(f'Switch batch augments{extra_info}.') + + if 'train_pipeline' in recipe: + self._switch_train_pipeline(runner, recipe['train_pipeline']) + runner.logger.info(f'Switch train pipeline{extra_info}.') + + if 'loss' in recipe: + self._switch_loss(runner, recipe['loss']) + runner.logger.info(f'Switch loss{extra_info}.') + + @staticmethod + def _switch_batch_augments(runner, batch_augments): + """switch the train augments.""" + model = runner.model + if is_model_wrapper(model): + model = model.module + + model.data_preprocessor.batch_augments = batch_augments + + @staticmethod + def _switch_train_pipeline(runner, train_pipeline): + """switch the train loader dataset pipeline.""" + + def switch_pipeline(dataset, pipeline): + if hasattr(dataset, 'pipeline'): + # for usual dataset + dataset.pipeline = pipeline + elif hasattr(dataset, 'datasets'): + # for concat dataset wrapper + for ds in dataset.datasets: + switch_pipeline(ds, pipeline) + elif hasattr(dataset, 'dataset'): + # for other dataset wrappers + switch_pipeline(dataset.dataset, pipeline) + else: + raise RuntimeError( + 'Cannot access the `pipeline` of the dataset.') + + train_loader = runner.train_loop.dataloader + switch_pipeline(train_loader.dataset, train_pipeline) + + # To restart the iterator of dataloader when `persistent_workers=True` + train_loader._iterator = None + + @staticmethod + def _switch_loss(runner, loss_module): + """switch the loss module.""" + model = runner.model + if is_model_wrapper(model, MODEL_WRAPPERS): + model = model.module + + if hasattr(model, 'loss_module'): + model.loss_module = loss_module + elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'): + model.head.loss_module = loss_module + else: + raise RuntimeError('Cannot access the `loss_module` of the model.') diff --git a/mmpretrain/engine/hooks/visualization_hook.py b/mmpretrain/engine/hooks/visualization_hook.py new file mode 100644 index 0000000..64d2230 --- /dev/null +++ b/mmpretrain/engine/hooks/visualization_hook.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import os.path as osp +from typing import Optional, Sequence + +from mmengine.fileio import join_path +from mmengine.hooks import Hook +from mmengine.runner import EpochBasedTrainLoop, Runner +from mmengine.visualization import Visualizer + +from mmpretrain.registry import HOOKS +from mmpretrain.structures import DataSample + + +@HOOKS.register_module() +class VisualizationHook(Hook): + """Classification Visualization Hook. Used to visualize validation and + testing prediction results. + + - If ``out_dir`` is specified, all storage backends are ignored + and save the image to the ``out_dir``. + - If ``show`` is True, plot the result image in a window, please + confirm you are able to access the graphical interface. + + Args: + enable (bool): Whether to enable this hook. Defaults to False. 
+ interval (int): The interval of samples to visualize. Defaults to 5000. + show (bool): Whether to display the drawn image. Defaults to False. + out_dir (str, optional): directory where painted images will be saved + in the testing process. If None, handle with the backends of the + visualizer. Defaults to None. + **kwargs: other keyword arguments of + :meth:`mmpretrain.visualization.UniversalVisualizer.visualize_cls`. + """ + + def __init__(self, + enable=False, + interval: int = 5000, + show: bool = False, + out_dir: Optional[str] = None, + **kwargs): + self._visualizer: Visualizer = Visualizer.get_current_instance() + + self.enable = enable + self.interval = interval + self.show = show + self.out_dir = out_dir + + self.draw_args = {**kwargs, 'show': show} + + def _draw_samples(self, + batch_idx: int, + data_batch: dict, + data_samples: Sequence[DataSample], + step: int = 0) -> None: + """Visualize every ``self.interval`` samples from a data batch. + + Args: + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DataSample`]): Outputs from model. + step (int): Global step value to record. Defaults to 0. + """ + if self.enable is False: + return + + batch_size = len(data_samples) + images = data_batch['inputs'] + start_idx = batch_size * batch_idx + end_idx = start_idx + batch_size + + # The first index divisible by the interval, after the start index + first_sample_id = math.ceil(start_idx / self.interval) * self.interval + + for sample_id in range(first_sample_id, end_idx, self.interval): + image = images[sample_id - start_idx] + image = image.permute(1, 2, 0).cpu().numpy().astype('uint8') + + data_sample = data_samples[sample_id - start_idx] + if 'img_path' in data_sample: + # osp.basename works on different platforms even file clients. + sample_name = osp.basename(data_sample.get('img_path')) + else: + sample_name = str(sample_id) + + draw_args = self.draw_args + if self.out_dir is not None: + draw_args['out_file'] = join_path(self.out_dir, + f'{sample_name}_{step}.png') + + self._visualizer.visualize_cls( + image=image, + data_sample=data_sample, + step=step, + name=sample_name, + **self.draw_args, + ) + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DataSample]) -> None: + """Visualize every ``self.interval`` samples during validation. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DataSample`]): Outputs from model. + """ + if isinstance(runner.train_loop, EpochBasedTrainLoop): + step = runner.epoch + else: + step = runner.iter + + self._draw_samples(batch_idx, data_batch, outputs, step=step) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DataSample]) -> None: + """Visualize every ``self.interval`` samples during test. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DetDataSample`]): Outputs from model. 
+ """ + self._draw_samples(batch_idx, data_batch, outputs, step=0) diff --git a/mmpretrain/engine/hooks/warmup_param_hook.py b/mmpretrain/engine/hooks/warmup_param_hook.py new file mode 100644 index 0000000..b45d891 --- /dev/null +++ b/mmpretrain/engine/hooks/warmup_param_hook.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import operator as op +from typing import Any, Optional, Union + +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS +from mmpretrain.utils import get_ori_model + + +@HOOKS.register_module() +class WarmupParamHook(Hook): + """This is a hook used for changing the parameters other than optimizations + that need to warmup inside the module. + + This hook can extend with more detailed warmup rule if necessary. + + Args: + param_name (str): The parameter name that needs to be altered. + module_name (str): Module name that belongs to the model. Such as + `head`, `head.loss`, etc. + warmup_epochs (int): The warmup epochs for this parameter. + """ + + def __init__( + self, + param_name: str, + module_name: str, + warmup_epochs: int, + ) -> None: + self.param_name = param_name + self.warmup_epochs = warmup_epochs + # getter for module which saves the changed parameter + self.module_getter = op.attrgetter(module_name) + + def get_param(self, runner) -> Any: + """Get the parameter.""" + try: + module = self.module_getter(get_ori_model(runner.model)) + return getattr(module, self.param_name) + except AttributeError as e: + raise AttributeError(f'{e}. Please check hook settings.') + + def set_param(self, runner, value) -> None: + """Set the parameter.""" + try: + module = self.module_getter(get_ori_model(runner.model)) + setattr(module, self.param_name, value) + except AttributeError as e: + raise AttributeError(f'{e}. Please check hook settings.') + + def before_train(self, runner) -> None: + """Get the original value before train.""" + self.ori_val = self.get_param(runner) + + def before_train_iter( + self, + runner, + batch_idx: int, + data_batch: Optional[Union[dict, tuple, list]] = None) -> None: + """Set the warmup value before each train iter.""" + cur_iter = runner.iter + iters_per_epoch = runner.max_iters / runner.max_epochs + new_val = self.ori_val * min( + 1, cur_iter / (self.warmup_epochs * iters_per_epoch)) + self.set_param(runner, new_val) diff --git a/mmpretrain/engine/optimizers/__init__.py b/mmpretrain/engine/optimizers/__init__.py new file mode 100644 index 0000000..bd53a37 --- /dev/null +++ b/mmpretrain/engine/optimizers/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .adan_t import Adan +from .lamb import Lamb +from .lars import LARS +from .layer_decay_optim_wrapper_constructor import \ + LearningRateDecayOptimWrapperConstructor + +__all__ = ['Lamb', 'Adan', 'LARS', 'LearningRateDecayOptimWrapperConstructor'] diff --git a/mmpretrain/engine/optimizers/adan_t.py b/mmpretrain/engine/optimizers/adan_t.py new file mode 100644 index 0000000..571a71b --- /dev/null +++ b/mmpretrain/engine/optimizers/adan_t.py @@ -0,0 +1,312 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + +from mmpretrain.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module() +class Adan(Optimizer): + """Implements a pytorch variant of Adan. + + Adan was proposed in + Adan : Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models. # noqa + https://arxiv.org/abs/2208.06677 + Arguments: + params (iterable): iterable of parameters to optimize + or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float, flot], optional): coefficients used + for computing running averages of gradient. + (default: (0.98, 0.92, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): decoupled weight decay + (L2 penalty) (default: 0) + max_grad_norm (float, optional): value used to clip + global grad norm (default: 0.0 no clip) + no_prox (bool): how to perform the decoupled weight decay + (default: False) + foreach (bool): if True would use torch._foreach implementation. + It's faster but uses slightly more memory. + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.98, 0.92, 0.99), + eps=1e-8, + weight_decay=0.0, + max_grad_norm=0.0, + no_prox=False, + foreach: bool = True): + if not 0.0 <= max_grad_norm: + raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm)) + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= eps: + raise ValueError('Invalid epsilon value: {}'.format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format( + betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format( + betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError('Invalid beta parameter at index 2: {}'.format( + betas[2])) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + no_prox=no_prox, + foreach=foreach) + super().__init__(params, defaults) + + def __setstate__(self, state): + super(Adan, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('no_prox', False) + + @torch.no_grad() + def restart_opt(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + if p.requires_grad: + state = self.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self): + """Performs a single optimization step.""" + if self.defaults['max_grad_norm'] > 0: + device = self.param_groups[0]['params'][0].device + global_grad_norm = torch.zeros(1, device=device) + + max_grad_norm = torch.tensor( + self.defaults['max_grad_norm'], device=device) + for group in 
self.param_groups: + + for p in group['params']: + if p.grad is not None: + grad = p.grad + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + group['eps'] + + clip_global_grad_norm = \ + torch.clamp(max_grad_norm / global_grad_norm, max=1.0) + else: + clip_global_grad_norm = 1.0 + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + exp_avg_diffs = [] + pre_grads = [] + + beta1, beta2, beta3 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support + # by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1 = 1.0 - beta1**group['step'] + bias_correction2 = 1.0 - beta2**group['step'] + bias_correction3 = 1.0 - beta3**group['step'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['exp_avg_diff'] = torch.zeros_like(p) + + if 'pre_grad' not in state or group['step'] == 1: + # at first step grad wouldn't be clipped + # by `clip_global_grad_norm` + # this is only to simplify implementation + state['pre_grad'] = p.grad + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + exp_avg_diffs.append(state['exp_avg_diff']) + pre_grads.append(state['pre_grad']) + + kwargs = dict( + params=params_with_grad, + grads=grads, + exp_avgs=exp_avgs, + exp_avg_sqs=exp_avg_sqs, + exp_avg_diffs=exp_avg_diffs, + pre_grads=pre_grads, + beta1=beta1, + beta2=beta2, + beta3=beta3, + bias_correction1=bias_correction1, + bias_correction2=bias_correction2, + bias_correction3_sqrt=math.sqrt(bias_correction3), + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + no_prox=group['no_prox'], + clip_global_grad_norm=clip_global_grad_norm, + ) + if group['foreach']: + copy_grads = _multi_tensor_adan(**kwargs) + else: + copy_grads = _single_tensor_adan(**kwargs) + + for p, copy_grad in zip(params_with_grad, copy_grads): + self.state[p]['pre_grad'] = copy_grad + + +def _single_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + copy_grads = [] + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg_diff = exp_avg_diffs[i] + pre_grad = pre_grads[i] + + grad = grad.mul_(clip_global_grad_norm) + copy_grads.append(grad.clone()) + + diff = grad - pre_grad + update = grad + beta2 * diff + + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t + exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2) # diff_t + exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3) # n_t + + denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps) + update = exp_avg / bias_correction1 + update.add_(beta2 * exp_avg_diff / bias_correction2).div_(denom) + + if no_prox: + param.mul_(1 - lr * weight_decay) + param.add_(update, alpha=-lr) + else: + param.add_(update, alpha=-lr) + param.div_(1 + lr * weight_decay) + return copy_grads + + 
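+# NOTE: The single-tensor kernel above and the multi-tensor kernel below
+# implement the same Adan update; ``step()`` picks the fused multi-tensor
+# path when the param group sets ``foreach=True``.
+#
+# A minimal usage sketch in an mmengine-style config (the hyper-parameter
+# values here are illustrative assumptions, not tuned settings)::
+#
+#     optim_wrapper = dict(
+#         type='OptimWrapper',
+#         optimizer=dict(
+#             type='Adan',
+#             lr=1e-3,
+#             betas=(0.98, 0.92, 0.99),
+#             weight_decay=0.02))
+
+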
+def _multi_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + if clip_global_grad_norm < 1.0: + torch._foreach_mul_(grads, clip_global_grad_norm.item()) + copy_grads = [g.clone() for g in grads] + + diff = torch._foreach_sub(grads, pre_grads) + # NOTE: line below while looking identical gives different result, + # due to float precision errors. + # using mul+add produces identical results to single-tensor, + # using add+alpha doesn't + # update = torch._foreach_add(grads, torch._foreach_mul(diff, beta2)) + update = torch._foreach_add(grads, diff, alpha=beta2) + + torch._foreach_mul_(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t + + torch._foreach_mul_(exp_avg_diffs, beta2) + torch._foreach_add_(exp_avg_diffs, diff, alpha=1 - beta2) # diff_t + + torch._foreach_mul_(exp_avg_sqs, beta3) + torch._foreach_addcmul_( + exp_avg_sqs, update, update, value=1 - beta3) # n_t + + denom = torch._foreach_sqrt(exp_avg_sqs) + torch._foreach_div_(denom, bias_correction3_sqrt) + torch._foreach_add_(denom, eps) + + update = torch._foreach_div(exp_avgs, bias_correction1) + # NOTE: same issue as above. + # beta2 * diff / bias_correction2 != diff * (beta2 / bias_correction2) # noqa + # using faster version by default. uncomment for tests to pass + # torch._foreach_add_(update, torch._foreach_div(torch._foreach_mul(exp_avg_diffs, beta2), bias_correction2)) # noqa + torch._foreach_add_( + update, torch._foreach_mul(exp_avg_diffs, beta2 / bias_correction2)) + torch._foreach_div_(update, denom) + + if no_prox: + torch._foreach_mul_(params, 1 - lr * weight_decay) + else: + torch._foreach_add_(params, update, alpha=-lr) + torch._foreach_div_(params, 1 + lr * weight_decay) + return copy_grads diff --git a/mmpretrain/engine/optimizers/lamb.py b/mmpretrain/engine/optimizers/lamb.py new file mode 100644 index 0000000..0b44a1c --- /dev/null +++ b/mmpretrain/engine/optimizers/lamb.py @@ -0,0 +1,228 @@ +"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb. + +This optimizer code was adapted from the following (starting with latest) +* https://github.com/HabanaAI/Model-References/blob/ +2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/ +LanguageModeling/Transformer-XL/pytorch/lamb.py +* https://github.com/cybertronai/pytorch-lamb + +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb +is to have a version that is +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or +cannot install/use APEX. + +In addition to some cleanup, this Lamb impl has been modified to support +PyTorch XLA and has been tested on TPU. + +Original copyrights for above sources are below. + +Modifications Copyright 2021 Ross Wightman +""" +# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. + +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math + +import torch +from torch.optim import Optimizer + +from mmpretrain.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module() +class Lamb(Optimizer): + """A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer. + + This class is copied from `timm`_. The LAMB was proposed in `Large Batch + Optimization for Deep Learning - Training BERT in 76 minutes`_. + + .. _timm: + https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. 
(default: True) + max_grad_norm (float, optional): value used to clip global grad norm + (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + """ # noqa: E501 + + def __init__(self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + grad_averaging=True, + max_grad_norm=1.0, + trust_clip=False, + always_adapt=False): + defaults = dict( + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm, + trust_clip=trust_clip, + always_adapt=always_adapt) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor( + 1.0, device=device + ) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + 'Lamb does not support sparse gradients, consider ' + 'SparseAdam instead.') + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars + # when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor( + self.defaults['max_grad_norm'], device=device) + clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor) + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or + # pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1**group['step'] + bias_correction2 = 1 - beta2**group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group['eps']) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # 
Layer-wise LR adaptation. By default, skip adaptation on + # parameters that are + # excluded from weight decay, unless always_adapt == True, + # then always enabled. + w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not + # working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group['trust_clip']: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group['lr']) + + return loss diff --git a/mmpretrain/engine/optimizers/lars.py b/mmpretrain/engine/optimizers/lars.py new file mode 100644 index 0000000..5e38887 --- /dev/null +++ b/mmpretrain/engine/optimizers/lars.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterable + +import torch +from torch.optim.optimizer import Optimizer + +from mmpretrain.registry import OPTIMIZERS + + +@OPTIMIZERS.register_module() +class LARS(Optimizer): + """Implements layer-wise adaptive rate scaling for SGD. + + Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. + `Large Batch Training of Convolutional Networks: + `_. + + Args: + params (Iterable): Iterable of parameters to optimize or dicts defining + parameter groups. + lr (float): Base learning rate. + momentum (float): Momentum factor. Defaults to 0. + weight_decay (float): Weight decay (L2 penalty). Defaults to 0. + dampening (float): Dampening for momentum. Defaults to 0. + eta (float): LARS coefficient. Defaults to 0.001. + nesterov (bool): Enables Nesterov momentum. Defaults to False. + eps (float): A small number to avoid dviding zero. Defaults to 1e-8. + + Example: + >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, + >>> weight_decay=1e-4, eta=1e-3) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + """ + + def __init__(self, + params: Iterable, + lr: float, + momentum: float = 0, + weight_decay: float = 0, + dampening: float = 0, + eta: float = 0.001, + nesterov: bool = False, + eps: float = 1e-8) -> None: + if not isinstance(lr, float) and lr < 0.0: + raise ValueError(f'Invalid learning rate: {lr}') + if momentum < 0.0: + raise ValueError(f'Invalid momentum value: {momentum}') + if weight_decay < 0.0: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + if eta < 0.0: + raise ValueError(f'Invalid LARS coefficient value: {eta}') + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + eta=eta) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError( + 'Nesterov momentum requires a momentum and zero dampening') + + self.eps = eps + super().__init__(params, defaults) + + def __setstate__(self, state) -> None: + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() + def step(self, closure=None) -> torch.Tensor: + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + eta = group['eta'] + nesterov = group['nesterov'] + lr = group['lr'] + lars_exclude = group.get('lars_exclude', False) + + for p in group['params']: + if p.grad is None: + continue + + d_p = p.grad + + if lars_exclude: + local_lr = 1. + else: + weight_norm = torch.norm(p).item() + grad_norm = torch.norm(d_p).item() + if weight_norm != 0 and grad_norm != 0: + # Compute local learning rate for this layer + local_lr = eta * weight_norm / \ + (grad_norm + weight_decay * weight_norm + self.eps) + else: + local_lr = 1. + + actual_lr = local_lr * lr + d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr) + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = \ + torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + p.add_(-d_p) + + return loss diff --git a/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py b/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py new file mode 100644 index 0000000..09c6abc --- /dev/null +++ b/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from typing import Callable, List, Optional + +from mmengine.logging import MMLogger +from mmengine.optim import DefaultOptimWrapperConstructor +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm +from torch import nn +from torch.nn import GroupNorm, LayerNorm + +from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): + """Different learning rates are set for different layers of backbone. + + By default, each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain the following fields: + + - ``layer_decay_rate`` (float): The learning rate of a parameter will + multiply it by multiple times according to the layer depth of the + parameter. Usually, it's less than 1, so that the earlier layers will + have a lower learning rate. Defaults to 1. + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in normalization layers). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization layers. + - ``flat_decay_mult`` (float): It will be multiplied to the weight + decay for all one-dimensional parameters + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_decay_mult`` will be + ignored. It should be a dict and may contain fields ``decay_mult``. + (The ``lr_mult`` is disabled in this constructor). + + Example: + + In the config file, you can use this constructor as below: + + .. 
code:: python + + optim_wrapper = dict( + optimizer=dict( + type='AdamW', + lr=4e-3, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999)), + constructor='LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + layer_decay_rate=0.75, # layer-wise lr decay factor + norm_decay_mult=0., + flat_decay_mult=0., + custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + })) + """ + + def add_params(self, + params: List[dict], + module: nn.Module, + prefix: str = '', + get_layer_depth: Optional[Callable] = None, + **kwargs) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (List[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + optimizer_cfg (dict): The configuration of optimizer. + prefix (str): The prefix of the module. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + logger = MMLogger.get_current_instance() + + # The model should have `get_layer_depth` method + if get_layer_depth is None and not hasattr(module, 'get_layer_depth'): + raise NotImplementedError('The layer-wise learning rate decay need' + f' the model {type(module)} has' + ' `get_layer_depth` method.') + else: + get_layer_depth = get_layer_depth or module.get_layer_depth + + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None) + flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None) + decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + param_name = prefix + name + if not param.requires_grad: + continue + + if self.base_wd is not None: + base_wd = self.base_wd + custom_key = next( + filter(lambda k: k in param_name, sorted_keys), None) + # custom parameters decay + if custom_key is not None: + custom_cfg = custom_keys[custom_key].copy() + decay_mult = custom_cfg.pop('decay_mult', 1.) + + param_group['weight_decay'] = base_wd * decay_mult + # add custom settings to param_group + param_group.update(custom_cfg) + # norm decay + elif is_norm and norm_decay_mult is not None: + param_group['weight_decay'] = base_wd * norm_decay_mult + # bias decay + elif name == 'bias' and bias_decay_mult is not None: + param_group['weight_decay'] = base_wd * bias_decay_mult + # flatten parameters decay + elif param.ndim == 1 and flat_decay_mult is not None: + param_group['weight_decay'] = base_wd * flat_decay_mult + else: + param_group['weight_decay'] = base_wd + + layer_id, max_id = get_layer_depth(param_name) + scale = decay_rate**(max_id - layer_id - 1) + param_group['lr'] = self.base_lr * scale + param_group['lr_scale'] = scale + param_group['layer_id'] = layer_id + param_group['param_name'] = param_name + + params.append(param_group) + + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}{child_name}.' 
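+            # Recurse into child modules so that every parameter in the model
+            # receives its layer-wise lr scale and weight-decay settings.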
+ self.add_params( + params, + child_mod, + prefix=child_prefix, + get_layer_depth=get_layer_depth, + ) + + if prefix == '': + layer_params = defaultdict(list) + for param in params: + layer_params[param['layer_id']].append(param) + for layer_id, layer_params in layer_params.items(): + lr_scale = layer_params[0]['lr_scale'] + lr = layer_params[0]['lr'] + msg = [ + f'layer {layer_id} params ' + f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):' + ] + for param in layer_params: + msg.append(f'\t{param["param_name"]}: ' + f'weight_decay={param["weight_decay"]:.3g}') + logger.debug('\n'.join(msg)) diff --git a/mmpretrain/engine/runners/__init__.py b/mmpretrain/engine/runners/__init__.py new file mode 100644 index 0000000..23206e1 --- /dev/null +++ b/mmpretrain/engine/runners/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .retrieval_loop import RetrievalTestLoop, RetrievalValLoop + +__all__ = ['RetrievalTestLoop', 'RetrievalValLoop'] diff --git a/mmpretrain/engine/runners/retrieval_loop.py b/mmpretrain/engine/runners/retrieval_loop.py new file mode 100644 index 0000000..d15387e --- /dev/null +++ b/mmpretrain/engine/runners/retrieval_loop.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +from mmengine.model import is_model_wrapper +from mmengine.runner import TestLoop, ValLoop, autocast + +from mmpretrain.registry import LOOPS + + +@LOOPS.register_module() +class RetrievalValLoop(ValLoop): + """Loop for multimodal retrieval val. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 valing. Defaults to + False. 
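+
+    Example:
+        An illustrative way to enable this loop from a config file (a sketch,
+        assuming the retrieval dataloader and evaluator are configured
+        elsewhere)::
+
+            val_cfg = dict(type='RetrievalValLoop')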
+ """ + + def run(self) -> dict: + """Launch val.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + + feats_local = [] + data_samples_local = [] + + for idx, data_batch in enumerate(self.dataloader): + with torch.no_grad(): + self.runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + if is_model_wrapper(self.runner.model): + data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501 + else: + data_preprocessor = self.runner.model.data_preprocessor + + # get features for retrieval instead of data samples + data_batch = data_preprocessor(data_batch, False) + feats = self.runner.model._run_forward( + data_batch, mode='tensor') + feats_local.append(feats) + data_samples_local.extend(data_batch['data_samples']) + self.runner.call_hook( + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=feats) + + # concatenate different features + feats_local = { + k: torch.cat([dic[k] for dic in feats_local]) + for k in feats_local[0] + } + + # get predictions + if is_model_wrapper(self.runner.model): + predict_all_fn = self.runner.model.module.predict_all + else: + predict_all_fn = self.runner.model.predict_all + + img_size = self.dataloader.dataset.img_size + text_size = self.dataloader.dataset.text_size + with torch.no_grad(): + i2t_data_samples, t2i_data_samples = predict_all_fn( + feats_local, + data_samples_local, + num_images=img_size, + num_texts=text_size, + ) + + # process in evaluator and compute metrics + self.evaluator.process(i2t_data_samples, None) + i2t_metrics = self.evaluator.evaluate(img_size) + i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()} + self.evaluator.process(t2i_data_samples, None) + t2i_metrics = self.evaluator.evaluate(text_size) + t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()} + metrics = {**i2t_metrics, **t2i_metrics} + + self.runner.call_hook('after_val_epoch', metrics=metrics) + self.runner.call_hook('after_val') + return metrics + + +@LOOPS.register_module() +class RetrievalTestLoop(TestLoop): + """Loop for multimodal retrieval test. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 testing. Defaults to + False. 
+ """ + + def run(self) -> dict: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + self.runner.model.eval() + + feats_local = [] + data_samples_local = [] + + for idx, data_batch in enumerate(self.dataloader): + with torch.no_grad(): + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + if is_model_wrapper(self.runner.model): + data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501 + else: + data_preprocessor = self.runner.model.data_preprocessor + # get features for retrieval instead of data samples + data_batch = data_preprocessor(data_batch, False) + feats = self.runner.model._run_forward( + data_batch, mode='tensor') + feats_local.append(feats) + data_samples_local.extend(data_batch['data_samples']) + self.runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=feats) + + # concatenate different features + feats_local = { + k: torch.cat([dic[k] for dic in feats_local]) + for k in feats_local[0] + } + + # get predictions + if is_model_wrapper(self.runner.model): + predict_all_fn = self.runner.model.module.predict_all + else: + predict_all_fn = self.runner.model.predict_all + + img_size = self.dataloader.dataset.img_size + text_size = self.dataloader.dataset.text_size + with torch.no_grad(): + i2t_data_samples, t2i_data_samples = predict_all_fn( + feats_local, + data_samples_local, + num_images=img_size, + num_texts=text_size, + ) + + # process in evaluator and compute metrics + self.evaluator.process(i2t_data_samples, None) + i2t_metrics = self.evaluator.evaluate(img_size) + i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()} + self.evaluator.process(t2i_data_samples, None) + t2i_metrics = self.evaluator.evaluate(text_size) + t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()} + metrics = {**i2t_metrics, **t2i_metrics} + + self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test') + return metrics diff --git a/mmpretrain/engine/schedulers/__init__.py b/mmpretrain/engine/schedulers/__init__.py new file mode 100644 index 0000000..68b6a54 --- /dev/null +++ b/mmpretrain/engine/schedulers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .weight_decay_scheduler import CosineAnnealingWeightDecay + +__all__ = ['CosineAnnealingWeightDecay'] diff --git a/mmpretrain/engine/schedulers/weight_decay_scheduler.py b/mmpretrain/engine/schedulers/weight_decay_scheduler.py new file mode 100644 index 0000000..7e725a4 --- /dev/null +++ b/mmpretrain/engine/schedulers/weight_decay_scheduler.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmengine.optim.scheduler import CosineAnnealingParamScheduler + +from mmpretrain.registry import PARAM_SCHEDULERS + + +class WeightDecaySchedulerMixin: + """A mixin class for learning rate schedulers.""" + + def __init__(self, optimizer, *args, **kwargs): + super().__init__(optimizer, 'weight_decay', *args, **kwargs) + + +@PARAM_SCHEDULERS.register_module() +class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin, + CosineAnnealingParamScheduler): + """Set the weight decay value of each parameter group using a cosine + annealing schedule. + + If the weight decay was set to be 0 initially, the weight decay value will + be 0 constantly during the training. 
+ """ + + def _get_value(self) -> list: + """Compute value using chainable form of the scheduler.""" + + def _get_eta_min(base_value): + if self.eta_min_ratio is None: + return self.eta_min + return base_value * self.eta_min_ratio + + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0: + weight_decay_value_list = [] + for base_value, group in zip(self.base_values, + self.optimizer.param_groups): + if base_value == 0: + group_value = 0 + else: + group_value = group[self.param_name] + ( + base_value - _get_eta_min(base_value)) * ( + 1 - math.cos(math.pi / self.T_max)) / 2 + weight_decay_value_list.append(group_value) + return weight_decay_value_list + + weight_decay_value_list = [] + for base_value, group in zip(self.base_values, + self.optimizer.param_groups): + if base_value == 0: + group_value = 0 + else: + group_value = ( + 1 + math.cos(math.pi * self.last_step / self.T_max)) / ( + 1 + math.cos(math.pi * + (self.last_step - 1) / self.T_max) + ) * (group[self.param_name] - + _get_eta_min(base_value)) + _get_eta_min(base_value) + weight_decay_value_list.append(group_value) + return weight_decay_value_list diff --git a/mmpretrain/evaluation/__init__.py b/mmpretrain/evaluation/__init__.py new file mode 100644 index 0000000..f70dc22 --- /dev/null +++ b/mmpretrain/evaluation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .functional import * # noqa: F401,F403 +from .metrics import * # noqa: F401,F403 diff --git a/mmpretrain/evaluation/functional/__init__.py b/mmpretrain/evaluation/functional/__init__.py new file mode 100644 index 0000000..ef101fe --- /dev/null +++ b/mmpretrain/evaluation/functional/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/evaluation/metrics/ANLS.py b/mmpretrain/evaluation/metrics/ANLS.py new file mode 100644 index 0000000..14917f1 --- /dev/null +++ b/mmpretrain/evaluation/metrics/ANLS.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class ANLS(BaseMetric): + """ANLS metric. + + Compute the Average Normalized Levenshtein Similarity(ANLS). + + Args: + threshold (float): ANLS threshold used for determining if the answer + has been correctly selected but not properly recognized, + or on the contrary, the output is a wrong text selected from the + options and given as an answer. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'ANLS' + + def __init__(self, + threshold: float = 0.5, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.threshold = threshold + + def process(self, data_batch, data_samples) -> None: + """Process one batch of data samples. 
+ + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer + } + + self.results.append(result) + + def compute_metrics(self, results: List) -> dict: + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + total_score = 0. + for result in results: + sample_score_list = [] + pred = ' '.join(result['pred_answer'].strip().lower().split()) + for gt in result['gt_answer']: + gt = ' '.join(gt.strip().lower().split()) + dist = levenshtein_distance(gt, pred) + length = max( + len(gt.upper()), len(result['pred_answer'].upper())) + sample_score_list.append(0.0 if length == 0 else float(dist) / + float(length)) + + per_sample_score = 1. - min(sample_score_list) + if per_sample_score < self.threshold: + per_sample_score = 0. + + total_score += per_sample_score + + total_score = total_score / len(results) + return {'ANLS': total_score} + + +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], + distances_[-1]))) + distances = distances_ + return distances[-1] diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py new file mode 100644 index 0000000..e572efe --- /dev/null +++ b/mmpretrain/evaluation/metrics/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ANLS import ANLS +from .caption import COCOCaption +from .gqa import GQAAcc +from .multi_label import AveragePrecision, MultiLabelMetric +from .multi_task import MultiTasksMetric +from .nocaps import NocapsSave +from .retrieval import RetrievalAveragePrecision, RetrievalRecall +from .scienceqa import ScienceQAMetric +from .shape_bias_label import ShapeBiasMetric +from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric +from .visual_grounding_eval import VisualGroundingMetric +from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric +from .vqa import ReportVQA, VQAAcc + +__all__ = [ + 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision', + 'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric', + 'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption', + 'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave', + 'RetrievalAveragePrecision', 'ShapeBiasMetric', 'ANLS' +] diff --git a/mmpretrain/evaluation/metrics/caption.py b/mmpretrain/evaluation/metrics/caption.py new file mode 100644 index 0000000..c4bffab --- /dev/null +++ b/mmpretrain/evaluation/metrics/caption.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import json +import os +import tempfile +from typing import List, Optional + +from mmengine.evaluator import BaseMetric +from mmengine.utils import track_iter_progress + +from mmpretrain.registry import METRICS +from mmpretrain.utils import require + +try: + from pycocoevalcap.eval import COCOEvalCap + from pycocotools.coco import COCO +except ImportError: + COCOEvalCap = None + COCO = None + + +@METRICS.register_module() +class COCOCaption(BaseMetric): + """Coco Caption evaluation wrapper. + + Save the generated captions and transform into coco format. + Calling COCO API for caption metrics. + + Args: + ann_file (str): the path for the COCO format caption ground truth + json file, load for evaluations. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + + @require('pycocoevalcap') + def __init__(self, + ann_file: str, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.ann_file = ann_file + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + + result['caption'] = data_sample.get('pred_caption') + result['image_id'] = int(data_sample.get('image_id')) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. 
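        # `save_result` below dumps the collected `results` to a flat JSON list
        # with one entry per image id (duplicates removed), e.g.
        #   [{"image_id": 1, "caption": "a man riding a wave"}, ...]
        # (ids and captions here are made-up placeholders). `coco_caption_eval`
        # then loads that file with pycocotools and scores it against
        # `self.ann_file` via pycocoevalcap, typically reporting BLEU-1..4,
        # METEOR, ROUGE-L and CIDEr.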
+ + with tempfile.TemporaryDirectory() as temp_dir: + + eval_result_file = save_result( + result=results, + result_dir=temp_dir, + filename='m4-caption_pred', + remove_duplicate='image_id', + ) + + coco_val = coco_caption_eval(eval_result_file, self.ann_file) + + return coco_val + + +def save_result(result, result_dir, filename, remove_duplicate=''): + """Saving predictions as json file for evaluation.""" + + # combine results from all processes + result_new = [] + + if remove_duplicate: + result_new = [] + id_list = [] + for res in track_iter_progress(result): + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + final_result_file_url = os.path.join(result_dir, '%s.json' % filename) + print(f'result file saved to {final_result_file_url}') + json.dump(result, open(final_result_file_url, 'w')) + + return final_result_file_url + + +def coco_caption_eval(results_file, ann_file): + """Evaluation between gt json and prediction json files.""" + # create coco object and coco_result object + coco = COCO(ann_file) + coco_result = coco.loadRes(results_file) + + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + + # make sure the image ids are the same + coco_eval.params['image_id'] = coco_result.getImgIds() + + # This will take some times at the first run + coco_eval.evaluate() + + # print output evaluation scores + for metric, score in coco_eval.eval.items(): + print(f'{metric}: {score:.3f}') + + return coco_eval.eval diff --git a/mmpretrain/evaluation/metrics/gqa.py b/mmpretrain/evaluation/metrics/gqa.py new file mode 100644 index 0000000..d5e8b07 --- /dev/null +++ b/mmpretrain/evaluation/metrics/gqa.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from mmengine.evaluator import BaseMetric + +from mmpretrain.evaluation.metrics.vqa import (_process_digit_article, + _process_punctuation) +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class GQAAcc(BaseMetric): + """GQA Acc metric. + + Compute GQA accuracy. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'GQA' + + def __init__(self, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch, data_samples) -> None: + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer + } + + self.results.append(result) + + def compute_metrics(self, results: List) -> dict: + """Compute the metrics from processed results. 
+ + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + acc = [] + for result in results: + pred_answer = self._process_answer(result['pred_answer']) + gt_answer = self._process_answer(result['gt_answer']) + gqa_acc = 1 if pred_answer == gt_answer else 0 + acc.append(gqa_acc) + + accuracy = sum(acc) / len(acc) + + metrics = {'acc': accuracy} + return metrics + + def _process_answer(self, answer) -> str: + answer = _process_punctuation(answer) + answer = _process_digit_article(answer) + return answer diff --git a/mmpretrain/evaluation/metrics/multi_label.py b/mmpretrain/evaluation/metrics/multi_label.py new file mode 100644 index 0000000..bd91aac --- /dev/null +++ b/mmpretrain/evaluation/metrics/multi_label.py @@ -0,0 +1,599 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .single_label import _precision_recall_f1_support, to_tensor + + +@METRICS.register_module() +class MultiLabelMetric(BaseMetric): + r"""A collection of precision, recall, f1-score and support for + multi-label tasks. + + The collection of metrics is for single-label multi-class classification. + And all these metrics are based on the confusion matrix of every category: + + .. image:: ../../_static/image/confusion-matrix.png + :width: 60% + :align: center + + All metrics can be formulated use variables above: + + **Precision** is the fraction of correct predictions in all predictions: + + .. math:: + \text{Precision} = \frac{TP}{TP+FP} + + **Recall** is the fraction of correct predictions in all targets: + + .. math:: + \text{Recall} = \frac{TP}{TP+FN} + + **F1-score** is the harmonic mean of the precision and recall: + + .. math:: + \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}} + + **Support** is the number of samples: + + .. math:: + \text{Support} = TP + TN + FN + FP + + Args: + thr (float, optional): Predictions with scores under the threshold + are considered as negative. If None, the ``topk`` predictions will + be considered as positive. If the ``topk`` is also None, use + ``thr=0.5`` as default. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. If None, use ``thr`` to determine positive + predictions. If both ``thr`` and ``topk`` are not None, use + ``thr``. Defaults to None. + items (Sequence[str]): The detailed metric items to evaluate, select + from "precision", "recall", "f1-score" and "support". + Defaults to ``('precision', 'recall', 'f1-score')``. + average (str | None): How to calculate the final metrics from the + confusion matrix of every category. It supports three modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories and + calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. 
+ prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import MultiLabelMetric + >>> # ------ The Basic Usage for category indices labels ------- + >>> y_pred = [[0], [1], [0, 1], [3]] + >>> y_true = [[0, 3], [0, 2], [1], [3]] + >>> # Output precision, recall, f1-score and support + >>> MultiLabelMetric.calculate( + ... y_pred, y_true, pred_indices=True, target_indices=True, num_classes=4) + (tensor(50.), tensor(50.), tensor(45.8333), tensor(6)) + >>> # ----------- The Basic Usage for one-hot labels ----------- + >>> y_pred = torch.tensor([[1, 1, 0, 0], + ... [1, 1, 0, 0], + ... [0, 0, 1, 0], + ... [0, 1, 0, 0], + ... [0, 1, 0, 0]]) + >>> y_true = torch.Tensor([[1, 1, 0, 0], + ... [0, 0, 1, 0], + ... [1, 1, 1, 0], + ... [1, 0, 0, 0], + ... [1, 0, 0, 0]]) + >>> MultiLabelMetric.calculate(y_pred, y_true) + (tensor(43.7500), tensor(31.2500), tensor(33.3333), tensor(8)) + >>> # --------- The Basic Usage for one-hot pred scores --------- + >>> y_pred = torch.rand(y_true.size()) + >>> y_pred + tensor([[0.4575, 0.7335, 0.3934, 0.2572], + [0.1318, 0.1004, 0.8248, 0.6448], + [0.8349, 0.6294, 0.7896, 0.2061], + [0.4037, 0.7308, 0.6713, 0.8374], + [0.3779, 0.4836, 0.0313, 0.0067]]) + >>> # Calculate with different threshold. + >>> MultiLabelMetric.calculate(y_pred, y_true, thr=0.1) + (tensor(42.5000), tensor(75.), tensor(53.1746), tensor(8)) + >>> # Calculate with topk. + >>> MultiLabelMetric.calculate(y_pred, y_true, topk=1) + (tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8)) + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_sampels = [ + ... DataSample().set_pred_score(pred).set_gt_score(gt) + ... 
for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))] + >>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5)) + >>> evaluator.process(data_sampels) + >>> evaluator.evaluate(1000) + { + 'multi-label/precision': 50.72898037055408, + 'multi-label/recall': 50.06836461357571, + 'multi-label/f1-score': 50.384466955258475 + } + >>> # Evaluate on each class by using topk strategy + >>> evaluator = Evaluator(metrics=MultiLabelMetric(topk=1, average=None)) + >>> evaluator.process(data_sampels) + >>> evaluator.evaluate(1000) + { + 'multi-label/precision_top1_classwise': [48.22, 50.54, 50.99, 44.18, 52.5], + 'multi-label/recall_top1_classwise': [18.92, 19.22, 19.92, 20.0, 20.27], + 'multi-label/f1-score_top1_classwise': [27.18, 27.85, 28.65, 27.54, 29.25] + } + """ # noqa: E501 + default_prefix: Optional[str] = 'multi-label' + + def __init__(self, + thr: Optional[float] = None, + topk: Optional[int] = None, + items: Sequence[str] = ('precision', 'recall', 'f1-score'), + average: Optional[str] = 'macro', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + + logger = MMLogger.get_current_instance() + if thr is None and topk is None: + thr = 0.5 + logger.warning('Neither thr nor k is given, set thr as 0.5 by ' + 'default.') + elif thr is not None and topk is not None: + logger.warning('Both thr and topk are given, ' + 'use threshold in favor of top-k.') + + self.thr = thr + self.topk = topk + self.average = average + + for item in items: + assert item in ['precision', 'recall', 'f1-score', 'support'], \ + f'The metric {item} is not supported by `SingleLabelMetric`,' \ + ' please choose from "precision", "recall", "f1-score" and ' \ + '"support".' + self.items = tuple(items) + + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + + result['pred_score'] = data_sample['pred_score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'gt_score' in data_sample: + result['gt_score'] = data_sample['gt_score'].clone() + else: + result['gt_score'] = label_to_onehot(data_sample['gt_label'], + num_classes) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. 
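        # What follows stacks the per-sample score vectors into (N, C) tensors,
        # binarizes them inside `calculate` (by `thr` or by `topk`) and keeps
        # only the requested items. Result keys carry a suffix describing the
        # positive-prediction rule (e.g. 'precision_thr-0.60' or 'recall_top3'),
        # plus '_classwise' when `average` is None or '_micro' when `average`
        # is 'micro'.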
+ metrics = {} + + target = torch.stack([res['gt_score'] for res in results]) + pred = torch.stack([res['pred_score'] for res in results]) + + metric_res = self.calculate( + pred, + target, + pred_indices=False, + target_indices=False, + average=self.average, + thr=self.thr, + topk=self.topk) + + def pack_results(precision, recall, f1_score, support): + single_metrics = {} + if 'precision' in self.items: + single_metrics['precision'] = precision + if 'recall' in self.items: + single_metrics['recall'] = recall + if 'f1-score' in self.items: + single_metrics['f1-score'] = f1_score + if 'support' in self.items: + single_metrics['support'] = support + return single_metrics + + if self.thr: + suffix = '' if self.thr == 0.5 else f'_thr-{self.thr:.2f}' + for k, v in pack_results(*metric_res).items(): + metrics[k + suffix] = v + else: + for k, v in pack_results(*metric_res).items(): + metrics[k + f'_top{self.topk}'] = v + + result_metrics = dict() + for k, v in metrics.items(): + if self.average is None: + result_metrics[k + '_classwise'] = v.detach().cpu().tolist() + elif self.average == 'macro': + result_metrics[k] = v.item() + else: + result_metrics[k + f'_{self.average}'] = v.item() + return result_metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + pred_indices: bool = False, + target_indices: bool = False, + average: Optional[str] = 'macro', + thr: Optional[float] = None, + topk: Optional[int] = None, + num_classes: Optional[int] = None + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Calculate the precision, recall, f1-score. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, num_classes)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, num_classes)`` or a sequence of index/onehot + format labels. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. If True, ``num_classes`` must be set. + Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. If True, ``num_classes`` must be set. + Defaults to False. + average (str | None): How to calculate the final metrics from + the confusion matrix of every category. It supports three + modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories + and calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + num_classes (Optional, int): The number of classes. If the ``pred`` + is indices instead of onehot, this argument is required. + Defaults to None. + + Returns: + Tuple: The tuple contains precision, recall and f1-score. + And the type of each item is: + + - torch.Tensor: A tensor for each metric. The shape is (1, ) if + ``average`` is not None, and (C, ) if ``average`` is None. + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. 
If neither is set, use ``thr=0.5`` as + default. + """ + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specicy from {average_options}.' + + def _format_label(label, is_indices): + """format various label to torch.Tensor.""" + if isinstance(label, np.ndarray): + assert label.ndim == 2, 'The shape `pred` and `target` ' \ + 'array must be (N, num_classes).' + label = torch.from_numpy(label) + elif isinstance(label, torch.Tensor): + assert label.ndim == 2, 'The shape `pred` and `target` ' \ + 'tensor must be (N, num_classes).' + elif isinstance(label, Sequence): + if is_indices: + assert num_classes is not None, 'For index-type labels, ' \ + 'please specify `num_classes`.' + label = torch.stack([ + label_to_onehot(indices, num_classes) + for indices in label + ]) + else: + label = torch.stack( + [to_tensor(onehot) for onehot in label]) + else: + raise TypeError( + 'The `pred` and `target` must be type of torch.tensor or ' + f'np.ndarray or sequence but get {type(label)}.') + return label + + pred = _format_label(pred, pred_indices) + target = _format_label(target, target_indices).long() + + assert pred.shape == target.shape, \ + f"The size of pred ({pred.shape}) doesn't match "\ + f'the target ({target.shape}).' + + if num_classes is not None: + assert pred.size(1) == num_classes, \ + f'The shape of `pred` ({pred.shape}) '\ + f"doesn't match the num_classes ({num_classes})." + num_classes = pred.size(1) + + thr = 0.5 if (thr is None and topk is None) else thr + + if thr is not None: + # a label is predicted positive if larger than thr + pos_inds = (pred >= thr).long() + else: + # top-k labels will be predicted positive for any example + _, topk_indices = pred.topk(topk) + pos_inds = torch.zeros_like(pred).scatter_(1, topk_indices, 1) + pos_inds = pos_inds.long() + + return _precision_recall_f1_support(pos_inds, target, average) + + +def _average_precision(pred: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + r"""Calculate the average precision for a single class. + + AP summarizes a precision-recall curve as the weighted mean of maximum + precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + pred (torch.Tensor): The model prediction with shape + ``(N, num_classes)``. + target (torch.Tensor): The target of predictions with shape + ``(N, num_classes)``. + + Returns: + torch.Tensor: average precision result. + """ + assert pred.shape == target.shape, \ + f"The size of pred ({pred.shape}) doesn't match "\ + f'the target ({target.shape}).' + + # a small value for division by zero errors + eps = torch.finfo(torch.float32).eps + + # get rid of -1 target such as difficult sample + # that is not wanted in evaluation results. 
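    # The computation below realizes \sum_n (R_n - R_{n-1}) P_n without building
    # the precision-recall curve explicitly: recall increases by 1/num_pos
    # exactly at each true positive and is flat elsewhere, so it is enough to
    # sum the precision at the ranks of the true positives and divide by the
    # number of positives.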
+ valid_index = target > -1 + pred = pred[valid_index] + target = target[valid_index] + + # sort examples + sorted_pred_inds = torch.argsort(pred, dim=0, descending=True) + sorted_target = target[sorted_pred_inds] + + # get indexes when gt_true is positive + pos_inds = sorted_target == 1 + + # Calculate cumulative tp case numbers + tps = torch.cumsum(pos_inds, 0) + total_pos = tps[-1].item() # the last of tensor may change later + + # Calculate cumulative tp&fp(pred_poss) case numbers + pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device) + pred_pos_nums[pred_pos_nums < eps] = eps + + tps[torch.logical_not(pos_inds)] = 0 + precision = tps / pred_pos_nums.float() + ap = torch.sum(precision, 0) / max(total_pos, eps) + return ap + + +@METRICS.register_module() +class AveragePrecision(BaseMetric): + r"""Calculate the average precision with respect of classes. + + AveragePrecision (AP) summarizes a precision-recall curve as the weighted + mean of maximum precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + average (str | None): How to calculate the final metrics from + every category. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. The result of this mode + is also called **mAP**. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + References + ---------- + 1. `Wikipedia entry for the Average precision + `_ + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import AveragePrecision + >>> # --------- The Basic Usage for one-hot pred scores --------- + >>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2], + ... [0.1, 0.2, 0.2, 0.1], + ... [0.7, 0.5, 0.9, 0.3], + ... [0.8, 0.1, 0.1, 0.2]]) + >>> y_true = torch.Tensor([[1, 1, 0, 0], + ... [0, 1, 0, 0], + ... [0, 0, 1, 0], + ... [1, 0, 0, 0]]) + >>> AveragePrecision.calculate(y_pred, y_true) + tensor(70.833) + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_pred_score(i).set_gt_score(j) + ... for i, j in zip(y_pred, y_true) + ... ] + >>> evaluator = Evaluator(metrics=AveragePrecision()) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(5) + {'multi-label/mAP': 70.83333587646484} + >>> # Evaluate on each class + >>> evaluator = Evaluator(metrics=AveragePrecision(average=None)) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(5) + {'multi-label/AP_classwise': [100., 83.33, 100., 0.]} + """ + default_prefix: Optional[str] = 'multi-label' + + def __init__(self, + average: Optional[str] = 'macro', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.average = average + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. 
+ + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + + result['pred_score'] = data_sample['pred_score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'gt_score' in data_sample: + result['gt_score'] = data_sample['gt_score'].clone() + else: + result['gt_score'] = label_to_onehot(data_sample['gt_label'], + num_classes) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. + + # concat + target = torch.stack([res['gt_score'] for res in results]) + pred = torch.stack([res['pred_score'] for res in results]) + + ap = self.calculate(pred, target, self.average) + + result_metrics = dict() + + if self.average is None: + result_metrics['AP_classwise'] = ap.detach().cpu().tolist() + else: + result_metrics['mAP'] = ap.item() + + return result_metrics + + @staticmethod + def calculate(pred: Union[torch.Tensor, np.ndarray], + target: Union[torch.Tensor, np.ndarray], + average: Optional[str] = 'macro') -> torch.Tensor: + r"""Calculate the average precision for a single class. + + Args: + pred (torch.Tensor | np.ndarray): The model predictions with + shape ``(N, num_classes)``. + target (torch.Tensor | np.ndarray): The target of predictions + with shape ``(N, num_classes)``. + average (str | None): The average method. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. The result of this mode + is also called mAP. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + + Returns: + torch.Tensor: the average precision of all classes. + """ + average_options = ['macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specicy from {average_options}.' + + pred = to_tensor(pred) + target = to_tensor(target) + assert pred.ndim == 2 and pred.shape == target.shape, \ + 'Both `pred` and `target` should have shape `(N, num_classes)`.' + + num_classes = pred.shape[1] + ap = pred.new_zeros(num_classes) + for k in range(num_classes): + ap[k] = _average_precision(pred[:, k], target[:, k]) + if average == 'macro': + return ap.mean() * 100.0 + else: + return ap * 100 diff --git a/mmpretrain/evaluation/metrics/multi_task.py b/mmpretrain/evaluation/metrics/multi_task.py new file mode 100644 index 0000000..0e6af76 --- /dev/null +++ b/mmpretrain/evaluation/metrics/multi_task.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
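# A small self-contained check of the per-class AP helper in multi_label.py
# above: precision is read off at every true-positive rank and averaged over
# the number of positives. The scores and labels below are made up.
import torch

pred = torch.tensor([0.9, 0.8, 0.4, 0.2])    # scores of one class for 4 samples
target = torch.tensor([1, 0, 1, 0])           # ground truth for that class

order = torch.argsort(pred, descending=True)
hits = target[order] == 1                     # positives in ranked order
tps = torch.cumsum(hits.long(), 0).float()    # cumulative true positives
ranks = torch.arange(1, len(pred) + 1).float()

ap = (tps / ranks)[hits].mean()               # (1/1 + 2/3) / 2
print(ap * 100)                               # ~83.33, on the same % scale as above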
+from typing import Dict, Sequence + +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class MultiTasksMetric(BaseMetric): + """Metrics for MultiTask + Args: + task_metrics(dict): a dictionary in the keys are the names of the tasks + and the values is a list of the metric corresponds to this task + Examples: + >>> import torch + >>> from mmpretrain.evaluation import MultiTasksMetric + # -------------------- The Basic Usage -------------------- + >>>task_metrics = { + 'task0': [dict(type='Accuracy', topk=(1, ))], + 'task1': [dict(type='Accuracy', topk=(1, 3))] + } + >>>pred = [{ + 'pred_task': { + 'task0': torch.tensor([0.7, 0.0, 0.3]), + 'task1': torch.tensor([0.5, 0.2, 0.3]) + }, + 'gt_task': { + 'task0': torch.tensor(0), + 'task1': torch.tensor(2) + } + }, { + 'pred_task': { + 'task0': torch.tensor([0.0, 0.0, 1.0]), + 'task1': torch.tensor([0.0, 0.0, 1.0]) + }, + 'gt_task': { + 'task0': torch.tensor(2), + 'task1': torch.tensor(2) + } + }] + >>>metric = MultiTasksMetric(task_metrics) + >>>metric.process(None, pred) + >>>results = metric.evaluate(2) + results = { + 'task0_accuracy/top1': 100.0, + 'task1_accuracy/top1': 50.0, + 'task1_accuracy/top3': 100.0 + } + """ + + def __init__(self, + task_metrics: Dict, + collect_device: str = 'cpu') -> None: + self.task_metrics = task_metrics + super().__init__(collect_device=collect_device) + + self._metrics = {} + for task_name in self.task_metrics.keys(): + self._metrics[task_name] = [] + for metric in self.task_metrics[task_name]: + self._metrics[task_name].append(METRICS.build(metric)) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for task_name in self.task_metrics.keys(): + filtered_data_samples = [] + for data_sample in data_samples: + eval_mask = data_sample[task_name]['eval_mask'] + if eval_mask: + filtered_data_samples.append(data_sample[task_name]) + for metric in self._metrics[task_name]: + metric.process(data_batch, filtered_data_samples) + + def compute_metrics(self, results: list) -> dict: + raise NotImplementedError( + 'compute metrics should not be used here directly') + + def evaluate(self, size): + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. When batch + size > 1, the dataloader may pad some data samples to make + sure all ranks have the same length of dataset slice. The + ``collect_results`` function will drop the padded data based on + this size. + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are + "{task_name}_{metric_name}" , and the values + are corresponding results. 
+ """ + metrics = {} + for task_name in self._metrics: + for metric in self._metrics[task_name]: + name = metric.__class__.__name__ + if name == 'MultiTasksMetric' or metric.results: + results = metric.evaluate(size) + else: + results = {metric.__class__.__name__: 0} + for key in results: + name = f'{task_name}_{key}' + if name in results: + """Inspired from https://github.com/open- + mmlab/mmengine/ bl ob/ed20a9cba52ceb371f7c825131636b9e2 + 747172e/mmengine/evalua tor/evaluator.py#L84-L87.""" + raise ValueError( + 'There are multiple metric results with the same' + f'metric name {name}. Please make sure all metrics' + 'have different prefixes.') + metrics[name] = results[key] + return metrics diff --git a/mmpretrain/evaluation/metrics/nocaps.py b/mmpretrain/evaluation/metrics/nocaps.py new file mode 100644 index 0000000..e8e1d06 --- /dev/null +++ b/mmpretrain/evaluation/metrics/nocaps.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import mmengine + +from mmpretrain.registry import METRICS +from mmpretrain.utils import require +from .caption import COCOCaption, save_result + +try: + from pycocoevalcap.eval import COCOEvalCap + from pycocotools.coco import COCO +except ImportError: + COCOEvalCap = None + COCO = None + + +@METRICS.register_module() +class NocapsSave(COCOCaption): + """Nocaps evaluation wrapper. + + Save the generated captions and transform into coco format. + The dumped file can be submitted to the official evluation system. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + + @require('pycocoevalcap') + def __init__(self, + save_dir: str = './', + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super(COCOCaption, self).__init__( + collect_device=collect_device, prefix=prefix) + self.save_dir = save_dir + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + """ + mmengine.mkdir_or_exist(self.save_dir) + save_result( + result=results, + result_dir=self.save_dir, + filename='nocap_pred', + remove_duplicate='image_id', + ) + + return dict() diff --git a/mmpretrain/evaluation/metrics/retrieval.py b/mmpretrain/evaluation/metrics/retrieval.py new file mode 100644 index 0000000..9813486 --- /dev/null +++ b/mmpretrain/evaluation/metrics/retrieval.py @@ -0,0 +1,445 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +import torch +from mmengine.evaluator import BaseMetric +from mmengine.utils import is_seq_of + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .single_label import to_tensor + + +@METRICS.register_module() +class RetrievalRecall(BaseMetric): + r"""Recall evaluation metric for image retrieval. + + Args: + topk (int | Sequence[int]): If the ground truth label matches one of + the best **k** predictions, the sample will be regard as a positive + prediction. 
If the parameter is a tuple, all of top-k recall will + be calculated and outputted together. Defaults to 1. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + Use in the code: + + >>> import torch + >>> from mmpretrain.evaluation import RetrievalRecall + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [[0], [1], [2], [3]] + >>> y_true = [[0, 1], [2], [1], [0, 3]] + >>> RetrievalRecall.calculate( + >>> y_pred, y_true, topk=1, pred_indices=True, target_indices=True) + [tensor([50.])] + >>> # Calculate the recall@1 and recall@5 for non-indices input. + >>> y_score = torch.rand((1000, 10)) + >>> import torch.nn.functional as F + >>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10) + >>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5)) + [tensor(9.3000), tensor(48.4000)] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label([0, 1]).set_pred_score( + ... torch.rand(10)) + ... for i in range(1000) + ... ] + >>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5))) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + {'retrieval/Recall@1': 20.700000762939453, + 'retrieval/Recall@5': 78.5999984741211} + + Use in OpenMMLab configs: + + .. code:: python + + val_evaluator = dict(type='RetrievalRecall', topk=(1, 5)) + test_evaluator = val_evaluator + """ + default_prefix: Optional[str] = 'retrieval' + + def __init__(self, + topk: Union[int, Sequence[int]], + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + topk = (topk, ) if isinstance(topk, int) else topk + + for k in topk: + if k <= 0: + raise ValueError('`topk` must be a ingter larger than 0 ' + 'or seq of ingter larger than 0.') + + self.topk = topk + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]): + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_score = data_sample['pred_score'].clone() + gt_label = data_sample['gt_label'] + + if 'gt_score' in data_sample: + target = data_sample.get('gt_score').clone() + else: + num_classes = pred_score.size()[-1] + target = label_to_onehot(gt_label, num_classes) + + # Because the retrieval output logit vector will be much larger + # compared to the normal classification, to save resources, the + # evaluation results are computed each batch here and then reduce + # all results at the end. + result = RetrievalRecall.calculate( + pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk) + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. 
+ + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + result_metrics = dict() + for i, k in enumerate(self.topk): + recall_at_k = sum([r[i].item() for r in results]) / len(results) + result_metrics[f'Recall@{k}'] = recall_at_k + + return result_metrics + + @staticmethod + def calculate(pred: Union[np.ndarray, torch.Tensor], + target: Union[np.ndarray, torch.Tensor], + topk: Union[int, Sequence[int]], + pred_indices: (bool) = False, + target_indices: (bool) = False) -> float: + """Calculate the average recall. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + topk (int, Sequence[int]): Predictions with the k-th highest + scores are considered as positive. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. Defaults to False. + + Returns: + List[float]: the average recalls. + """ + topk = (topk, ) if isinstance(topk, int) else topk + for k in topk: + if k <= 0: + raise ValueError('`topk` must be a ingter larger than 0 ' + 'or seq of ingter larger than 0.') + + max_keep = max(topk) + pred = _format_pred(pred, max_keep, pred_indices) + target = _format_target(target, target_indices) + + assert len(pred) == len(target), ( + f'Length of `pred`({len(pred)}) and `target` ({len(target)}) ' + f'must be the same.') + + num_samples = len(pred) + results = [] + for k in topk: + recalls = torch.zeros(num_samples) + for i, (sample_pred, + sample_target) in enumerate(zip(pred, target)): + sample_pred = np.array(to_tensor(sample_pred).cpu()) + sample_target = np.array(to_tensor(sample_target).cpu()) + recalls[i] = int(np.in1d(sample_pred[:k], sample_target).max()) + results.append(recalls.mean() * 100) + return results + + +@METRICS.register_module() +class RetrievalAveragePrecision(BaseMetric): + r"""Calculate the average precision for image retrieval. + + Args: + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. + mode (str, optional): The mode to calculate AP, choose from + 'IR'(information retrieval) and 'integrate'. Defaults to 'IR'. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Note: + If the ``mode`` set to 'IR', use the stanford AP calculation of + information retrieval as in wikipedia page[1]; if set to 'integrate', + the method implemented integrates over the precision-recall curve + by averaging two adjacent precision points, then multiplying by the + recall step like mAP in Detection task. This is the convention for + the Revisited Oxford/Paris datasets[2]. 
+ + References: + [1] `Wikipedia entry for the Average precision `_ + + [2] `The Oxford Buildings Dataset + `_ + + Examples: + Use in code: + + >>> import torch + >>> import numpy as np + >>> from mmcls.evaluation import RetrievalAveragePrecision + >>> # using index format inputs + >>> pred = [ torch.Tensor([idx for idx in range(100)]) ] * 3 + >>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]] + >>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True) + 29.246031746031747 + >>> # using tensor format inputs + >>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2) + >>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2) + >>> RetrievalAveragePrecision.calculate(pred, target, 10) + 62.222222222222214 + + Use in OpenMMLab config files: + + .. code:: python + + val_evaluator = dict(type='RetrievalAveragePrecision', topk=100) + test_evaluator = val_evaluator + """ + + default_prefix: Optional[str] = 'retrieval' + + def __init__(self, + topk: Optional[int] = None, + mode: Optional[str] = 'IR', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + if topk is None or (isinstance(topk, int) and topk <= 0): + raise ValueError('`topk` must be a ingter larger than 0.') + + mode_options = ['IR', 'integrate'] + assert mode in mode_options, \ + f'Invalid `mode` argument, please specify from {mode_options}.' + + self.topk = topk + self.mode = mode + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]): + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_score = data_sample.get('pred_score').clone() + + if 'gt_score' in data_sample: + target = data_sample.get('gt_score').clone() + else: + gt_label = data_sample.get('gt_label') + num_classes = pred_score.size()[-1] + target = label_to_onehot(gt_label, num_classes) + + # Because the retrieval output logit vector will be much larger + # compared to the normal classification, to save resources, the + # evaluation results are computed each batch here and then reduce + # all results at the end. + result = RetrievalAveragePrecision.calculate( + pred_score.unsqueeze(0), + target.unsqueeze(0), + self.topk, + mode=self.mode) + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + result_metrics = dict() + result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item() + + return result_metrics + + @staticmethod + def calculate(pred: Union[np.ndarray, torch.Tensor], + target: Union[np.ndarray, torch.Tensor], + topk: Optional[int] = None, + pred_indices: (bool) = False, + target_indices: (bool) = False, + mode: str = 'IR') -> float: + """Calculate the average precision. + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. 
A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + topk (int, optional): Predictions with the k-th highest scores + are considered as positive. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. Defaults to False. + mode (Optional[str]): The mode to calculate AP, choose from + 'IR'(information retrieval) and 'integrate'. Defaults to 'IR'. + + Note: + If the ``mode`` set to 'IR', use the stanford AP calculation of + information retrieval as in wikipedia page; if set to 'integrate', + the method implemented integrates over the precision-recall curve + by averaging two adjacent precision points, then multiplying by the + recall step like mAP in Detection task. This is the convention for + the Revisited Oxford/Paris datasets. + + Returns: + float: the average precision of the query image. + + References: + [1] `Wikipedia entry for Average precision(information_retrieval) + `_ + [2] `The Oxford Buildings Dataset 0 else 1 + cur_precision = (i + 1) / (rank + 1) + prediction = (old_precision + cur_precision) / 2 + ap += prediction + ap = ap / len(target) + + return ap * 100 + + +def _format_pred(label, topk=None, is_indices=False): + """format various label to List[indices].""" + if is_indices: + assert isinstance(label, Sequence), \ + '`pred` must be Sequence of indices when' \ + f' `pred_indices` set to True, but get {type(label)}' + for i, sample_pred in enumerate(label): + assert is_seq_of(sample_pred, int) or isinstance( + sample_pred, (np.ndarray, torch.Tensor)), \ + '`pred` should be Sequence of indices when `pred_indices`' \ + f'set to True. but pred[{i}] is {sample_pred}' + if topk: + label[i] = sample_pred[:min(topk, len(sample_pred))] + return label + if isinstance(label, np.ndarray): + label = torch.from_numpy(label) + elif not isinstance(label, torch.Tensor): + raise TypeError(f'The pred must be type of torch.tensor, ' + f'np.ndarray or Sequence but get {type(label)}.') + topk = topk if topk else label.size()[-1] + _, indices = label.topk(topk) + return indices + + +def _format_target(label, is_indices=False): + """format various label to List[indices].""" + if is_indices: + assert isinstance(label, Sequence), \ + '`target` must be Sequence of indices when' \ + f' `target_indices` set to True, but get {type(label)}' + for i, sample_gt in enumerate(label): + assert is_seq_of(sample_gt, int) or isinstance( + sample_gt, (np.ndarray, torch.Tensor)), \ + '`target` should be Sequence of indices when ' \ + f'`target_indices` set to True. but target[{i}] is {sample_gt}' + return label + + if isinstance(label, np.ndarray): + label = torch.from_numpy(label) + elif isinstance(label, Sequence) and not mmengine.is_str(label): + label = torch.tensor(label) + elif not isinstance(label, torch.Tensor): + raise TypeError(f'The pred must be type of torch.tensor, ' + f'np.ndarray or Sequence but get {type(label)}.') + + indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label] + return indices diff --git a/mmpretrain/evaluation/metrics/scienceqa.py b/mmpretrain/evaluation/metrics/scienceqa.py new file mode 100644 index 0000000..ebf01c7 --- /dev/null +++ b/mmpretrain/evaluation/metrics/scienceqa.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
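# A compact sketch of the Recall@K rule used by `RetrievalRecall` in
# retrieval.py above: a query counts as recalled at K if any of its top-K
# retrieved indices is relevant. The toy scores and relevance lists are made up.
import numpy as np
import torch

pred_scores = torch.tensor([[0.1, 0.7, 0.2],   # query 0 ranks item 1 first
                            [0.8, 0.1, 0.1]])  # query 1 ranks item 0 first
relevant = [[2], [0, 2]]                        # relevant item indices per query

for k in (1, 2):
    hits = []
    for scores, rel in zip(pred_scores, relevant):
        topk_idx = scores.topk(k).indices.numpy()
        hits.append(int(np.in1d(topk_idx, rel).max()))
    print(f'Recall@{k}: {100.0 * sum(hits) / len(hits):.1f}')
# Recall@1: 50.0 (only query 1 hits at K=1)
# Recall@2: 100.0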
+import random +from typing import List, Optional + +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +def get_pred_idx(prediction: str, choices: List[str], + options: List[str]) -> int: # noqa + """Get the index (e.g. 2) from the prediction (e.g. 'C') + + Args: + prediction (str): The prediction from the model, + from ['A', 'B', 'C', 'D', 'E'] + choices (List(str)): The choices for the question, + from ['A', 'B', 'C', 'D', 'E'] + options (List(str)): The options for the question, + from ['A', 'B', 'C', 'D', 'E'] + + Returns: + int: The index of the prediction, from [0, 1, 2, 3, 4] + """ + if prediction in options[:len(choices)]: + return options.index(prediction) + else: + return random.choice(range(len(choices))) + + +@METRICS.register_module() +class ScienceQAMetric(BaseMetric): + """Evaluation Metric for ScienceQA. + + Args: + options (List(str)): Options for each question. Defaults to + ["A", "B", "C", "D", "E"]. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + + def __init__(self, + options: List[str] = ['A', 'B', 'C', 'D', 'E'], + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.options = options + + def process(self, data_batch, data_samples) -> None: + """Process one batch of data samples. + + data_samples should contain the following keys: + 1. pred_answer (str): The prediction from the model, + from ['A', 'B', 'C', 'D', 'E'] + 2. choices (List(str)): The choices for the question, + from ['A', 'B', 'C', 'D', 'E'] + 3. grade (int): The grade for the question, from grade1 to grade12 + 4. subject (str): The subject for the question, from + ['natural science', 'social science', 'language science'] + 5. answer (str): The answer for the question, from + ['A', 'B', 'C', 'D', 'E'] + 6. hint (str): The hint for the question + 7. has_image (bool): Whether or not the question has image + + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + choices = data_sample.get('choices') + result['prediction'] = get_pred_idx( + data_sample.get('pred_answer'), choices, self.options) + result['grade'] = data_sample.get('grade') + result['subject'] = data_sample.get('subject') + result['answer'] = data_sample.get('gt_answer') + hint = data_sample.get('hint') + has_image = data_sample.get('has_image', False) + result['no_context'] = True if not has_image and len( + hint) == 0 else False # noqa + result['has_text'] = True if len(hint) > 0 else False + result['has_image'] = has_image + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List) -> dict: + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. 
The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. + metrics = dict() + + all_acc = [] + acc_natural = [] + acc_social = [] + acc_language = [] + acc_has_text = [] + acc_has_image = [] + acc_no_context = [] + acc_grade_1_6 = [] + acc_grade_7_12 = [] + + for result in results: + correct = result['prediction'] == result['answer'] + all_acc.append(correct) + # different subjects + if result['subject'] == 'natural science': + acc_natural.append(correct) + elif result['subject'] == 'social science': + acc_social.append(correct) + elif result['subject'] == 'language science': + acc_language.append(correct) + + # different context + if result['has_text']: + acc_has_text.append(correct) + elif result['has_image']: + acc_has_image.append(correct) + elif result['no_context']: + acc_no_context.append(correct) + + # different grade + if result['grade'] in [ + 'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6' + ]: + acc_grade_1_6.append(correct) + elif result['grade'] in [ + 'grade7', 'grade8', 'grade9', 'grade10', 'grade11', + 'grade12' + ]: + acc_grade_7_12.append(correct) + + metrics['all_acc'] = sum(all_acc) / len(all_acc) + if len(acc_natural) > 0: + metrics['acc_natural'] = sum(acc_natural) / len(acc_natural) + if len(acc_social) > 0: + metrics['acc_social'] = sum(acc_social) / len(acc_social) + if len(acc_language) > 0: + metrics['acc_language'] = sum(acc_language) / len(acc_language) + if len(acc_has_text) > 0: + metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text) + if len(acc_has_image) > 0: + metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image) + if len(acc_no_context) > 0: + metrics['acc_no_context'] = sum(acc_no_context) / len( + acc_no_context) + if len(acc_grade_1_6) > 0: + metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6) + if len(acc_grade_7_12) > 0: + metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len( + acc_grade_7_12) + + return metrics diff --git a/mmpretrain/evaluation/metrics/shape_bias_label.py b/mmpretrain/evaluation/metrics/shape_bias_label.py new file mode 100644 index 0000000..27c80a3 --- /dev/null +++ b/mmpretrain/evaluation/metrics/shape_bias_label.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import csv +import os +import os.path as osp +from typing import List, Sequence + +import numpy as np +import torch +from mmengine.dist.utils import get_rank +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +@METRICS.register_module() +class ShapeBiasMetric(BaseMetric): + """Evaluate the model on ``cue_conflict`` dataset. + + This module will evaluate the model on an OOD dataset, cue_conflict, in + order to measure the shape bias of the model. In addition to compuate the + Top-1 accuracy, this module also generate a csv file to record the + detailed prediction results, such that this csv file can be used to + generate the shape bias curve. + + Args: + csv_dir (str): The directory to save the csv file. + model_name (str): The name of the csv file. Please note that the + model name should be an unique identifier. + dataset_name (str): The name of the dataset. Default: 'cue_conflict'. 
+ """ + + # mapping several classes from ImageNet-1K to the same category + airplane_indices = [404] + bear_indices = [294, 295, 296, 297] + bicycle_indices = [444, 671] + bird_indices = [ + 8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83, + 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129, + 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, + 145 + ] + boat_indices = [472, 554, 625, 814, 914] + bottle_indices = [440, 720, 737, 898, 899, 901, 907] + car_indices = [436, 511, 817] + cat_indices = [281, 282, 283, 284, 285, 286] + chair_indices = [423, 559, 765, 857] + clock_indices = [409, 530, 892] + dog_indices = [ + 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, + 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194, + 195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, + 224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, + 239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254, + 255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268 + ] + elephant_indices = [385, 386] + keyboard_indices = [508, 878] + knife_indices = [499] + oven_indices = [766] + truck_indices = [555, 569, 656, 675, 717, 734, 864, 867] + + def __init__(self, + csv_dir: str, + model_name: str, + dataset_name: str = 'cue_conflict', + **kwargs) -> None: + super().__init__(**kwargs) + + self.categories = sorted([ + 'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock', + 'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car', + 'bird', 'dog' + ]) + self.csv_dir = csv_dir + self.model_name = model_name + self.dataset_name = dataset_name + if get_rank() == 0: + self.csv_path = self.create_csv() + + def process(self, data_batch, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + result['pred_label'] = data_sample['pred_label'].cpu() + result['gt_label'] = data_sample['gt_label'].cpu() + result['gt_category'] = data_sample['img_path'].split('/')[-2] + result['img_name'] = data_sample['img_path'].split('/')[-1] + + aggregated_category_probabilities = [] + # get the prediction for each category of current instance + for category in self.categories: + category_indices = getattr(self, f'{category}_indices') + category_probabilities = torch.gather( + result['pred_score'], 0, + torch.tensor(category_indices)).mean() + aggregated_category_probabilities.append( + category_probabilities) + # sort the probabilities in descending order + pred_indices = torch.stack(aggregated_category_probabilities + ).argsort(descending=True).numpy() + result['pred_category'] = np.take(self.categories, pred_indices) + + # Save the result to `self.results`. 
+ self.results.append(result) + + def create_csv(self) -> str: + """Create a csv file to store the results.""" + session_name = 'session-1' + csv_path = osp.join( + self.csv_dir, self.dataset_name + '_' + self.model_name + '_' + + session_name + '.csv') + if osp.exists(csv_path): + os.remove(csv_path) + directory = osp.dirname(csv_path) + if not osp.exists(directory): + os.makedirs(directory, exist_ok=True) + with open(csv_path, 'w') as f: + writer = csv.writer(f) + writer.writerow([ + 'subj', 'session', 'trial', 'rt', 'object_response', + 'category', 'condition', 'imagename' + ]) + return csv_path + + def dump_results_to_csv(self, results: List[dict]) -> None: + """Dump the results to a csv file. + + Args: + results (List[dict]): A list of results. + """ + for i, result in enumerate(results): + img_name = result['img_name'] + category = result['gt_category'] + condition = 'NaN' + with open(self.csv_path, 'a') as f: + writer = csv.writer(f) + writer.writerow([ + self.model_name, 1, i + 1, 'NaN', + result['pred_category'][0], category, condition, img_name + ]) + + def compute_metrics(self, results: List[dict]) -> dict: + """Compute the metrics from the results. + + Args: + results (List[dict]): A list of results. + + Returns: + dict: A dict of metrics. + """ + if get_rank() == 0: + self.dump_results_to_csv(results) + metrics = dict() + metrics['accuracy/top1'] = np.mean([ + result['pred_category'][0] == result['gt_category'] + for result in results + ]) + + return metrics diff --git a/mmpretrain/evaluation/metrics/single_label.py b/mmpretrain/evaluation/metrics/single_label.py new file mode 100644 index 0000000..f9329b9 --- /dev/null +++ b/mmpretrain/evaluation/metrics/single_label.py @@ -0,0 +1,776 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +def to_tensor(value): + """Convert value to torch.Tensor.""" + if isinstance(value, np.ndarray): + value = torch.from_numpy(value) + elif isinstance(value, Sequence) and not mmengine.is_str(value): + value = torch.tensor(value) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'{type(value)} is not an available argument.') + return value + + +def _precision_recall_f1_support(pred_positive, gt_positive, average): + """calculate base classification task metrics, such as precision, recall, + f1_score, support.""" + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.' + + # ignore -1 target such as difficult sample that is not wanted + # in evaluation results. 
+ # only for calculate multi-label without affecting single-label behavior + ignored_index = gt_positive == -1 + pred_positive[ignored_index] = 0 + gt_positive[ignored_index] = 0 + + class_correct = (pred_positive & gt_positive) + if average == 'micro': + tp_sum = class_correct.sum() + pred_sum = pred_positive.sum() + gt_sum = gt_positive.sum() + else: + tp_sum = class_correct.sum(0) + pred_sum = pred_positive.sum(0) + gt_sum = gt_positive.sum(0) + + precision = tp_sum / torch.clamp(pred_sum, min=1).float() * 100 + recall = tp_sum / torch.clamp(gt_sum, min=1).float() * 100 + f1_score = 2 * precision * recall / torch.clamp( + precision + recall, min=torch.finfo(torch.float32).eps) + if average in ['macro', 'micro']: + precision = precision.mean(0) + recall = recall.mean(0) + f1_score = f1_score.mean(0) + support = gt_sum.sum(0) + else: + support = gt_sum + return precision, recall, f1_score, support + + +@METRICS.register_module() +class Accuracy(BaseMetric): + r"""Accuracy evaluation metric. + + For either binary classification or multi-class classification, the + accuracy is the fraction of correct predictions in all predictions: + + .. math:: + + \text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}} + + Args: + topk (int | Sequence[int]): If the ground truth label matches one of + the best **k** predictions, the sample will be regard as a positive + prediction. If the parameter is a tuple, all of top-k accuracy will + be calculated and outputted together. Defaults to 1. + thrs (Sequence[float | None] | float | None): If a float, predictions + with score lower than the threshold will be regard as the negative + prediction. If None, not apply threshold. If the parameter is a + tuple, accuracy based on all thresholds will be calculated and + outputted together. Defaults to 0. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import Accuracy + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [0, 2, 1, 3] + >>> y_true = [0, 1, 2, 3] + >>> Accuracy.calculate(y_pred, y_true) + tensor([50.]) + >>> # Calculate the top1 and top5 accuracy. + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.zeros((1000, )) + >>> Accuracy.calculate(y_score, y_true, topk=(1, 5)) + [[tensor([9.9000])], [tensor([51.5000])]] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label(0).set_pred_score(torch.rand(10)) + ... for i in range(1000) + ... 
] + >>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5))) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + { + 'accuracy/top1': 9.300000190734863, + 'accuracy/top5': 51.20000076293945 + } + """ + default_prefix: Optional[str] = 'accuracy' + + def __init__(self, + topk: Union[int, Sequence[int]] = (1, ), + thrs: Union[float, Sequence[Union[float, None]], None] = 0., + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + if isinstance(topk, int): + self.topk = (topk, ) + else: + self.topk = tuple(topk) + + if isinstance(thrs, float) or thrs is None: + self.thrs = (thrs, ) + else: + self.thrs = tuple(thrs) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + result['pred_label'] = data_sample['pred_label'].cpu() + result['gt_label'] = data_sample['gt_label'].cpu() + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. + metrics = {} + + # concat + target = torch.cat([res['gt_label'] for res in results]) + if 'pred_score' in results[0]: + pred = torch.stack([res['pred_score'] for res in results]) + + try: + acc = self.calculate(pred, target, self.topk, self.thrs) + except ValueError as e: + # If the topk is invalid. + raise ValueError( + str(e) + ' Please check the `val_evaluator` and ' + '`test_evaluator` fields in your config file.') + + multi_thrs = len(self.thrs) > 1 + for i, k in enumerate(self.topk): + for j, thr in enumerate(self.thrs): + name = f'top{k}' + if multi_thrs: + name += '_no-thr' if thr is None else f'_thr-{thr:.2f}' + metrics[name] = acc[i][j].item() + else: + # If only label in the `pred_label`. + pred = torch.cat([res['pred_label'] for res in results]) + acc = self.calculate(pred, target, self.topk, self.thrs) + metrics['top1'] = acc.item() + + return metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + topk: Sequence[int] = (1, ), + thrs: Sequence[Union[float, None]] = (0., ), + ) -> Union[torch.Tensor, List[List[torch.Tensor]]]: + """Calculate the accuracy. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. It can be labels (N, ), or scores of every + class (N, C). + target (torch.Tensor | np.ndarray | Sequence): The target of + each prediction with shape (N, ). + thrs (Sequence[float | None]): Predictions with scores under + the thresholds are considered negative. It's only used + when ``pred`` is scores. None means no thresholds. + Defaults to (0., ). + thrs (Sequence[float]): Predictions with scores under + the thresholds are considered negative. It's only used + when ``pred`` is scores. 
Defaults to (0., ). + + Returns: + torch.Tensor | List[List[torch.Tensor]]: Accuracy. + + - torch.Tensor: If the ``pred`` is a sequence of label instead of + score (number of dimensions is 1). Only return a top-1 accuracy + tensor, and ignore the argument ``topk` and ``thrs``. + - List[List[torch.Tensor]]: If the ``pred`` is a sequence of score + (number of dimensions is 2). Return the accuracy on each ``topk`` + and ``thrs``. And the first dim is ``topk``, the second dim is + ``thrs``. + """ + + pred = to_tensor(pred) + target = to_tensor(target).to(torch.int64) + num = pred.size(0) + assert pred.size(0) == target.size(0), \ + f"The size of pred ({pred.size(0)}) doesn't match "\ + f'the target ({target.size(0)}).' + + if pred.ndim == 1: + # For pred label, ignore topk and acc + pred_label = pred.int() + correct = pred.eq(target).float().sum(0, keepdim=True) + acc = correct.mul_(100. / num) + return acc + else: + # For pred score, calculate on all topk and thresholds. + pred = pred.float() + maxk = max(topk) + + if maxk > pred.size(1): + raise ValueError( + f'Top-{maxk} accuracy is unavailable since the number of ' + f'categories is {pred.size(1)}.') + + pred_score, pred_label = pred.topk(maxk, dim=1) + pred_label = pred_label.t() + correct = pred_label.eq(target.view(1, -1).expand_as(pred_label)) + results = [] + for k in topk: + results.append([]) + for thr in thrs: + # Only prediction values larger than thr are counted + # as correct + _correct = correct + if thr is not None: + _correct = _correct & (pred_score.t() > thr) + correct_k = _correct[:k].reshape(-1).float().sum( + 0, keepdim=True) + acc = correct_k.mul_(100. / num) + results[-1].append(acc) + return results + + +@METRICS.register_module() +class SingleLabelMetric(BaseMetric): + r"""A collection of precision, recall, f1-score and support for + single-label tasks. + + The collection of metrics is for single-label multi-class classification. + And all these metrics are based on the confusion matrix of every category: + + .. image:: ../../_static/image/confusion-matrix.png + :width: 60% + :align: center + + All metrics can be formulated use variables above: + + **Precision** is the fraction of correct predictions in all predictions: + + .. math:: + \text{Precision} = \frac{TP}{TP+FP} + + **Recall** is the fraction of correct predictions in all targets: + + .. math:: + \text{Recall} = \frac{TP}{TP+FN} + + **F1-score** is the harmonic mean of the precision and recall: + + .. math:: + \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}} + + **Support** is the number of samples: + + .. math:: + \text{Support} = TP + TN + FN + FP + + Args: + thrs (Sequence[float | None] | float | None): If a float, predictions + with score lower than the threshold will be regard as the negative + prediction. If None, only the top-1 prediction will be regard as + the positive prediction. If the parameter is a tuple, accuracy + based on all thresholds will be calculated and outputted together. + Defaults to 0. + items (Sequence[str]): The detailed metric items to evaluate, select + from "precision", "recall", "f1-score" and "support". + Defaults to ``('precision', 'recall', 'f1-score')``. + average (str | None): How to calculate the final metrics from the + confusion matrix of every category. It supports three modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. 
+ - `"micro"`: Average the confusion matrix over all categories and + calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output directly. + + Defaults to "macro". + num_classes (int, optional): The number of classes. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.evaluation import SingleLabelMetric + >>> # -------------------- The Basic Usage -------------------- + >>> y_pred = [0, 1, 1, 3] + >>> y_true = [0, 2, 1, 3] + >>> # Output precision, recall, f1-score and support. + >>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4) + (tensor(62.5000), tensor(75.), tensor(66.6667), tensor(4)) + >>> # Calculate with different thresholds. + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.zeros((1000, )) + >>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9)) + [(tensor(10.), tensor(0.9500), tensor(1.7352), tensor(1000)), + (tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))] + >>> + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmpretrain.structures import DataSample + >>> from mmengine.evaluator import Evaluator + >>> data_samples = [ + ... DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5)) + ... for i in range(1000) + ... ] + >>> evaluator = Evaluator(metrics=SingleLabelMetric()) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + {'single-label/precision': 19.650691986083984, + 'single-label/recall': 19.600000381469727, + 'single-label/f1-score': 19.619548797607422} + >>> # Evaluate on each class + >>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None)) + >>> evaluator.process(data_samples) + >>> evaluator.evaluate(1000) + { + 'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1], + 'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0], + 'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0] + } + """ # noqa: E501 + default_prefix: Optional[str] = 'single-label' + + def __init__(self, + thrs: Union[float, Sequence[Union[float, None]], None] = 0., + items: Sequence[str] = ('precision', 'recall', 'f1-score'), + average: Optional[str] = 'macro', + num_classes: Optional[int] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + if isinstance(thrs, float) or thrs is None: + self.thrs = (thrs, ) + else: + self.thrs = tuple(thrs) + + for item in items: + assert item in ['precision', 'recall', 'f1-score', 'support'], \ + f'The metric {item} is not supported by `SingleLabelMetric`,' \ + ' please specify from "precision", "recall", "f1-score" and ' \ + '"support".' + self.items = tuple(items) + self.average = average + self.num_classes = num_classes + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. 
+ data_samples (Sequence[dict]): A batch of outputs from the model. + """ + + for data_sample in data_samples: + result = dict() + if 'pred_score' in data_sample: + result['pred_score'] = data_sample['pred_score'].cpu() + else: + num_classes = self.num_classes or data_sample.get( + 'num_classes') + assert num_classes is not None, \ + 'The `num_classes` must be specified if no `pred_score`.' + result['pred_label'] = data_sample['pred_label'].cpu() + result['num_classes'] = num_classes + result['gt_label'] = data_sample['gt_label'].cpu() + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. + metrics = {} + + def pack_results(precision, recall, f1_score, support): + single_metrics = {} + if 'precision' in self.items: + single_metrics['precision'] = precision + if 'recall' in self.items: + single_metrics['recall'] = recall + if 'f1-score' in self.items: + single_metrics['f1-score'] = f1_score + if 'support' in self.items: + single_metrics['support'] = support + return single_metrics + + # concat + target = torch.cat([res['gt_label'] for res in results]) + if 'pred_score' in results[0]: + pred = torch.stack([res['pred_score'] for res in results]) + metrics_list = self.calculate( + pred, target, thrs=self.thrs, average=self.average) + + multi_thrs = len(self.thrs) > 1 + for i, thr in enumerate(self.thrs): + if multi_thrs: + suffix = '_no-thr' if thr is None else f'_thr-{thr:.2f}' + else: + suffix = '' + + for k, v in pack_results(*metrics_list[i]).items(): + metrics[k + suffix] = v + else: + # If only label in the `pred_label`. + pred = torch.cat([res['pred_label'] for res in results]) + res = self.calculate( + pred, + target, + average=self.average, + num_classes=results[0]['num_classes']) + metrics = pack_results(*res) + + result_metrics = dict() + for k, v in metrics.items(): + + if self.average is None: + result_metrics[k + '_classwise'] = v.cpu().detach().tolist() + elif self.average == 'micro': + result_metrics[k + f'_{self.average}'] = v.item() + else: + result_metrics[k] = v.item() + + return result_metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + thrs: Sequence[Union[float, None]] = (0., ), + average: Optional[str] = 'macro', + num_classes: Optional[int] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Calculate the precision, recall, f1-score and support. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. It can be labels (N, ), or scores of every + class (N, C). + target (torch.Tensor | np.ndarray | Sequence): The target of + each prediction with shape (N, ). + thrs (Sequence[float | None]): Predictions with scores under + the thresholds are considered negative. It's only used + when ``pred`` is scores. None means no thresholds. + Defaults to (0., ). + average (str | None): How to calculate the final metrics from + the confusion matrix of every category. 
It supports three + modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Average the confusion matrix over all categories + and calculate metrics on the mean confusion matrix. + - `None`: Calculate metrics of every category and output + directly. + + Defaults to "macro". + num_classes (Optional, int): The number of classes. If the ``pred`` + is label instead of scores, this argument is required. + Defaults to None. + + Returns: + Tuple: The tuple contains precision, recall and f1-score. + And the type of each item is: + + - torch.Tensor: If the ``pred`` is a sequence of label instead of + score (number of dimensions is 1). Only returns a tensor for + each metric. The shape is (1, ) if ``classwise`` is False, and + (C, ) if ``classwise`` is True. + - List[torch.Tensor]: If the ``pred`` is a sequence of score + (number of dimensions is 2). Return the metrics on each ``thrs``. + The shape of tensor is (1, ) if ``classwise`` is False, and (C, ) + if ``classwise`` is True. + """ + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.' + + pred = to_tensor(pred) + target = to_tensor(target).to(torch.int64) + assert pred.size(0) == target.size(0), \ + f"The size of pred ({pred.size(0)}) doesn't match "\ + f'the target ({target.size(0)}).' + + if pred.ndim == 1: + assert num_classes is not None, \ + 'Please specify the `num_classes` if the `pred` is labels ' \ + 'intead of scores.' + gt_positive = F.one_hot(target.flatten(), num_classes) + pred_positive = F.one_hot(pred.to(torch.int64), num_classes) + return _precision_recall_f1_support(pred_positive, gt_positive, + average) + else: + # For pred score, calculate on all thresholds. + num_classes = pred.size(1) + pred_score, pred_label = torch.topk(pred, k=1) + pred_score = pred_score.flatten() + pred_label = pred_label.flatten() + + gt_positive = F.one_hot(target.flatten(), num_classes) + + results = [] + for thr in thrs: + pred_positive = F.one_hot(pred_label, num_classes) + if thr is not None: + pred_positive[pred_score <= thr] = 0 + results.append( + _precision_recall_f1_support(pred_positive, gt_positive, + average)) + + return results + + +@METRICS.register_module() +class ConfusionMatrix(BaseMetric): + r"""A metric to calculate confusion matrix for single-label tasks. + + Args: + num_classes (int, optional): The number of classes. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + + 1. The basic usage. + + >>> import torch + >>> from mmpretrain.evaluation import ConfusionMatrix + >>> y_pred = [0, 1, 1, 3] + >>> y_true = [0, 2, 1, 3] + >>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4) + tensor([[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1]]) + >>> # plot the confusion matrix + >>> import matplotlib.pyplot as plt + >>> y_score = torch.rand((1000, 10)) + >>> y_true = torch.randint(10, (1000, )) + >>> matrix = ConfusionMatrix.calculate(y_score, y_true) + >>> ConfusionMatrix().plot(matrix) + >>> plt.show() + + 2. In the config file + + .. 
code:: python + + val_evaluator = dict(type='ConfusionMatrix') + test_evaluator = dict(type='ConfusionMatrix') + """ # noqa: E501 + default_prefix = 'confusion_matrix' + + def __init__(self, + num_classes: Optional[int] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device, prefix) + + self.num_classes = num_classes + + def process(self, data_batch, data_samples: Sequence[dict]) -> None: + for data_sample in data_samples: + if 'pred_score' in data_sample: + pred_score = data_sample['pred_score'] + pred_label = pred_score.argmax(dim=0, keepdim=True) + self.num_classes = pred_score.size(0) + else: + pred_label = data_sample['pred_label'] + + self.results.append({ + 'pred_label': pred_label, + 'gt_label': data_sample['gt_label'], + }) + + def compute_metrics(self, results: list) -> dict: + pred_labels = [] + gt_labels = [] + for result in results: + pred_labels.append(result['pred_label']) + gt_labels.append(result['gt_label']) + confusion_matrix = ConfusionMatrix.calculate( + torch.cat(pred_labels), + torch.cat(gt_labels), + num_classes=self.num_classes) + return {'result': confusion_matrix} + + @staticmethod + def calculate(pred, target, num_classes=None) -> dict: + """Calculate the confusion matrix for single-label task. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. It can be labels (N, ), or scores of every + class (N, C). + target (torch.Tensor | np.ndarray | Sequence): The target of + each prediction with shape (N, ). + num_classes (Optional, int): The number of classes. If the ``pred`` + is label instead of scores, this argument is required. + Defaults to None. + + Returns: + torch.Tensor: The confusion matrix. + """ + pred = to_tensor(pred) + target_label = to_tensor(target).int() + + assert pred.size(0) == target_label.size(0), \ + f"The size of pred ({pred.size(0)}) doesn't match "\ + f'the target ({target_label.size(0)}).' + assert target_label.ndim == 1 + + if pred.ndim == 1: + assert num_classes is not None, \ + 'Please specify the `num_classes` if the `pred` is labels ' \ + 'intead of scores.' + pred_label = pred + else: + num_classes = num_classes or pred.size(1) + pred_label = torch.argmax(pred, dim=1).flatten() + + with torch.no_grad(): + indices = num_classes * target_label + pred_label + matrix = torch.bincount(indices, minlength=num_classes**2) + matrix = matrix.reshape(num_classes, num_classes) + + return matrix + + @staticmethod + def plot(confusion_matrix: torch.Tensor, + include_values: bool = False, + cmap: str = 'viridis', + classes: Optional[List[str]] = None, + colorbar: bool = True, + show: bool = True): + """Draw a confusion matrix by matplotlib. + + Modified from `Scikit-Learn + `_ + + Args: + confusion_matrix (torch.Tensor): The confusion matrix to draw. + include_values (bool): Whether to draw the values in the figure. + Defaults to False. + cmap (str): The color map to use. Defaults to use "viridis". + classes (list[str], optional): The names of categories. + Defaults to None, which means to use index number. + colorbar (bool): Whether to show the colorbar. Defaults to True. + show (bool): Whether to show the figure immediately. + Defaults to True. 
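# ---------------------------------------------------------------------------
# Editor's note: an illustrative follow-up, not part of the original patch.
# Because the matrix built by `ConfusionMatrix.calculate` is indexed as
# [true label, predicted label], per-class recall can be read off its diagonal.
# The numbers below are made up for the example.
import torch

matrix = torch.tensor([[5, 1, 0],
                       [2, 7, 1],
                       [0, 0, 4]])
per_class_recall = matrix.diag().float() / matrix.sum(dim=1).clamp(min=1)
overall_accuracy = matrix.diag().sum().float() / matrix.sum()
# ---------------------------------------------------------------------------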
+ """ # noqa: E501 + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(10, 10)) + + num_classes = confusion_matrix.size(0) + + im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap) + text_ = None + cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0) + + if include_values: + text_ = np.empty_like(confusion_matrix, dtype=object) + + # print text with appropriate color depending on background + thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0 + + for i, j in product(range(num_classes), range(num_classes)): + color = cmap_max if confusion_matrix[i, + j] < thresh else cmap_min + + text_cm = format(confusion_matrix[i, j], '.2g') + text_d = format(confusion_matrix[i, j], 'd') + if len(text_d) < len(text_cm): + text_cm = text_d + + text_[i, j] = ax.text( + j, i, text_cm, ha='center', va='center', color=color) + + display_labels = classes or np.arange(num_classes) + + if colorbar: + fig.colorbar(im_, ax=ax) + ax.set( + xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=display_labels, + yticklabels=display_labels, + ylabel='True label', + xlabel='Predicted label', + ) + ax.invert_yaxis() + ax.xaxis.tick_top() + + ax.set_ylim((num_classes - 0.5, -0.5)) + # Automatically rotate the x labels. + fig.autofmt_xdate(ha='center') + + if show: + plt.show() + return fig diff --git a/mmpretrain/evaluation/metrics/visual_grounding_eval.py b/mmpretrain/evaluation/metrics/visual_grounding_eval.py new file mode 100644 index 0000000..ad16e5a --- /dev/null +++ b/mmpretrain/evaluation/metrics/visual_grounding_eval.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torchvision.ops.boxes as boxes +from mmengine.evaluator import BaseMetric + +from mmpretrain.registry import METRICS + + +def aligned_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor): + area1 = boxes.box_area(boxes1) + area2 = boxes.box_area(boxes2) + + lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (B, 2) + rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (B, 2) + + wh = boxes._upcast(rb - lt).clamp(min=0) # (B, 2) + inter = wh[:, 0] * wh[:, 1] # (B, ) + + union = area1 + area2 - inter + iou = inter / union + return iou + + +@METRICS.register_module() +class VisualGroundingMetric(BaseMetric): + """Visual Grounding evaluator. + + Calculate the box mIOU and box grounding accuracy for visual grounding + model. + + Args: + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'visual-grounding' + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. 
+ """ + for preds in data_samples: + + pred_box = preds['pred_bboxes'].squeeze() + box_gt = torch.Tensor(preds['gt_bboxes']).squeeze() + + result = { + 'box': pred_box.to('cpu').squeeze(), + 'box_target': box_gt.squeeze(), + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + pred_boxes = torch.stack([each['box'] for each in results]) + gt_boxes = torch.stack([each['box_target'] for each in results]) + iou = aligned_box_iou(pred_boxes, gt_boxes) + accu_num = torch.sum(iou >= 0.5) + + miou = torch.mean(iou) + acc = accu_num / len(gt_boxes) + coco_val = {'miou': miou, 'acc': acc} + return coco_val diff --git a/mmpretrain/evaluation/metrics/voc_multi_label.py b/mmpretrain/evaluation/metrics/voc_multi_label.py new file mode 100644 index 0000000..1034852 --- /dev/null +++ b/mmpretrain/evaluation/metrics/voc_multi_label.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmpretrain.registry import METRICS +from mmpretrain.structures import label_to_onehot +from .multi_label import AveragePrecision, MultiLabelMetric + + +class VOCMetricMixin: + """A mixin class for VOC dataset metrics, VOC annotations have extra + `difficult` attribute for each object, therefore, extra option is needed + for calculating VOC metrics. + + Args: + difficult_as_postive (Optional[bool]): Whether to map the difficult + labels as positive in one-hot ground truth for evaluation. If it + set to True, map difficult gt labels to positive ones(1), If it + set to False, map difficult gt labels to negative ones(0). + Defaults to None, the difficult labels will be set to '-1'. + """ + + def __init__(self, + *arg, + difficult_as_positive: Optional[bool] = None, + **kwarg): + self.difficult_as_positive = difficult_as_positive + super().__init__(*arg, **kwarg) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + result = dict() + gt_label = data_sample['gt_label'] + gt_label_difficult = data_sample['gt_label_difficult'] + + result['pred_score'] = data_sample['pred_score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'gt_score' in data_sample: + result['gt_score'] = data_sample['gt_score'].clone() + else: + result['gt_score'] = label_to_onehot(gt_label, num_classes) + + # VOC annotation labels all the objects in a single image + # therefore, some categories are appeared both in + # difficult objects and non-difficult objects. + # Here we reckon those labels which are only exists in difficult + # objects as difficult labels. + difficult_label = set(gt_label_difficult) - ( + set(gt_label_difficult) & set(gt_label.tolist())) + + # set difficult label for better eval + if self.difficult_as_positive is None: + result['gt_score'][[*difficult_label]] = -1 + elif self.difficult_as_positive: + result['gt_score'][[*difficult_label]] = 1 + + # Save the result to `self.results`. 
+ self.results.append(result) + + +@METRICS.register_module() +class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric): + """A collection of metrics for multi-label multi-class classification task + based on confusion matrix for VOC dataset. + + It includes precision, recall, f1-score and support. + + Args: + difficult_as_postive (Optional[bool]): Whether to map the difficult + labels as positive in one-hot ground truth for evaluation. If it + set to True, map difficult gt labels to positive ones(1), If it + set to False, map difficult gt labels to negative ones(0). + Defaults to None, the difficult labels will be set to '-1'. + **kwarg: Refers to `MultiLabelMetric` for detailed docstrings. + """ + + +@METRICS.register_module() +class VOCAveragePrecision(VOCMetricMixin, AveragePrecision): + """Calculate the average precision with respect of classes for VOC dataset. + + Args: + difficult_as_postive (Optional[bool]): Whether to map the difficult + labels as positive in one-hot ground truth for evaluation. If it + set to True, map difficult gt labels to positive ones(1), If it + set to False, map difficult gt labels to negative ones(0). + Defaults to None, the difficult labels will be set to '-1'. + **kwarg: Refers to `AveragePrecision` for detailed docstrings. + """ diff --git a/mmpretrain/evaluation/metrics/vqa.py b/mmpretrain/evaluation/metrics/vqa.py new file mode 100644 index 0000000..fd77ba9 --- /dev/null +++ b/mmpretrain/evaluation/metrics/vqa.py @@ -0,0 +1,315 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Partly adopted from https://github.com/GT-Vision-Lab/VQA +# Copyright (c) 2014, Aishwarya Agrawal +from typing import List, Optional + +import mmengine +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger + +from mmpretrain.registry import METRICS + + +def _process_punctuation(inText): + import re + outText = inText + punct = [ + ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', + '>', '<', '@', '`', ',', '?', '!' 
+ ] + commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 + periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 + for p in punct: + if (p + ' ' in inText or ' ' + p in inText) or (re.search( + commaStrip, inText) is not None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = periodStrip.sub('', outText, re.UNICODE) + return outText + + +def _process_digit_article(inText): + outText = [] + tempText = inText.lower().split() + articles = ['a', 'an', 'the'] + manualMap = { + 'none': '0', + 'zero': '0', + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10', + } + contractions = { + 'aint': "ain't", + 'arent': "aren't", + 'cant': "can't", + 'couldve': "could've", + 'couldnt': "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + 'didnt': "didn't", + 'doesnt': "doesn't", + 'dont': "don't", + 'hadnt': "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + 'hasnt': "hasn't", + 'havent': "haven't", + 'hed': "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + 'hes': "he's", + 'howd': "how'd", + 'howll': "how'll", + 'hows': "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + 'Im': "I'm", + 'Ive': "I've", + 'isnt': "isn't", + 'itd': "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + 'itll': "it'll", + "let's": "let's", + 'maam': "ma'am", + 'mightnt': "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + 'mightve': "might've", + 'mustnt': "mustn't", + 'mustve': "must've", + 'neednt': "needn't", + 'notve': "not've", + 'oclock': "o'clock", + 'oughtnt': "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + 'shant': "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + 'shouldve': "should've", + 'shouldnt': "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": 'somebodyd', + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + 'somebodyll': "somebody'll", + 'somebodys': "somebody's", + 'someoned': "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + 'someonell': "someone'll", + 'someones': "someone's", + 'somethingd': "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + 'somethingll': "something'll", + 'thats': "that's", + 'thered': "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + 'therere': "there're", + 'theres': "there's", + 'theyd': "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + 'theyll': "they'll", + 'theyre': "they're", + 'theyve': "they've", + 'twas': "'twas", + 'wasnt': "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + 'weve': "we've", + 'werent': "weren't", + 'whatll': "what'll", + 'whatre': "what're", + 'whats': "what's", + 'whatve': "what've", + 'whens': "when's", + 'whered': "where'd", + 'wheres': "where's", + 'whereve': "where've", + 'whod': "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + 'wholl': "who'll", + 'whos': "who's", + 'whove': "who've", + 'whyll': "why'll", + 'whyre': "why're", + 'whys': "why's", + 'wont': "won't", + 'wouldve': "would've", + 'wouldnt': "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + 'yall': "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + 'youd': "you'd", + "youd've": "you'd've", + "you'dve": 
"you'd've", + 'youll': "you'll", + 'youre': "you're", + 'youve': "you've", + } + for word in tempText: + word = manualMap.setdefault(word, word) + if word not in articles: + outText.append(word) + for wordId, word in enumerate(outText): + if word in contractions: + outText[wordId] = contractions[word] + outText = ' '.join(outText) + return outText + + +@METRICS.register_module() +class VQAAcc(BaseMetric): + '''VQA Acc metric. + Args: + + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + ''' + default_prefix = 'VQA' + + def __init__(self, + full_score_weight: float = 0.3, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.full_score_weight = full_score_weight + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + gt_answer_weight = sample.get('gt_answer_weight') + if isinstance(gt_answer, str): + gt_answer = [gt_answer] + if gt_answer_weight is None: + gt_answer_weight = [1. / (len(gt_answer))] * len(gt_answer) + + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer, + 'gt_answer_weight': gt_answer_weight, + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + acc = [] + for result in results: + pred_answer = self._process_answer(result['pred_answer']) + gt_answer = [ + self._process_answer(answer) for answer in result['gt_answer'] + ] + answer_weight = result['gt_answer_weight'] + + weight_sum = 0 + for i, gt in enumerate(gt_answer): + if gt == pred_answer: + weight_sum += answer_weight[i] + vqa_acc = min(1.0, weight_sum / self.full_score_weight) + acc.append(vqa_acc) + + accuracy = sum(acc) / len(acc) * 100 + + metrics = {'acc': accuracy} + return metrics + + def _process_answer(self, answer): + answer = answer.replace('\n', ' ') + answer = answer.replace('\t', ' ') + answer = answer.strip() + answer = _process_punctuation(answer) + answer = _process_digit_article(answer) + return answer + + +@METRICS.register_module() +class ReportVQA(BaseMetric): + """Dump VQA result to the standard json format for VQA evaluation. + + Args: + file_path (str): The file path to save the result file. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. 
+ If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'VQA' + + def __init__(self, + file_path: str, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + if not file_path.endswith('.json'): + raise ValueError('The output file must be a json file.') + self.file_path = file_path + + def process(self, data_batch, data_samples) -> None: + """transfer tensors in predictions to CPU.""" + for sample in data_samples: + question_id = sample['question_id'] + pred_answer = sample['pred_answer'] + + result = { + 'question_id': int(question_id), + 'answer': pred_answer, + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Dump the result to json file.""" + mmengine.dump(results, self.file_path) + logger = MMLogger.get_current_instance() + logger.info(f'Results has been saved to {self.file_path}.') + return {} diff --git a/mmpretrain/models/__init__.py b/mmpretrain/models/__init__.py new file mode 100644 index 0000000..3f58311 --- /dev/null +++ b/mmpretrain/models/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backbones import * # noqa: F401,F403 +from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS, + build_backbone, build_classifier, build_head, build_loss, + build_neck) +from .classifiers import * # noqa: F401,F403 +from .heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .multimodal import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .peft import * # noqa: F401,F403 +from .retrievers import * # noqa: F401,F403 +from .selfsup import * # noqa: F401,F403 +from .tta import * # noqa: F401,F403 +from .utils import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone', + 'build_head', 'build_neck', 'build_loss', 'build_classifier' +] diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py new file mode 100644 index 0000000..60e37fb --- /dev/null +++ b/mmpretrain/models/backbones/__init__.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
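# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch, not part of the original patch, of how
# the backbones imported below are typically instantiated through the MODELS
# registry exposed by mmpretrain.registry; the ResNet depth and the input shape
# are placeholders and assume mmpretrain is installed.
import torch
from mmpretrain.registry import MODELS

backbone = MODELS.build(dict(type='ResNet', depth=18))
feats = backbone(torch.rand(1, 3, 224, 224))   # tuple of feature maps
# ---------------------------------------------------------------------------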
+from .alexnet import AlexNet +from .beit import BEiTViT +from .conformer import Conformer +from .convmixer import ConvMixer +from .convnext import ConvNeXt +from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt +from .davit import DaViT +from .deit import DistilledVisionTransformer +from .deit3 import DeiT3 +from .densenet import DenseNet +from .edgenext import EdgeNeXt +from .efficientformer import EfficientFormer +from .efficientnet import EfficientNet +from .efficientnet_v2 import EfficientNetV2 +from .hivit import HiViT +from .hornet import HorNet +from .hrnet import HRNet +from .inception_v3 import InceptionV3 +from .lenet import LeNet5 +from .levit import LeViT +from .mixmim import MixMIMTransformer +from .mlp_mixer import MlpMixer +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .mobileone import MobileOne +from .mobilevit import MobileViT +from .mvit import MViT +from .poolformer import PoolFormer +from .regnet import RegNet +from .replknet import RepLKNet +from .repmlp import RepMLPNet +from .repvgg import RepVGG +from .res2net import Res2Net +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1c, ResNetV1d +from .resnet_cifar import ResNet_CIFAR +from .resnext import ResNeXt +from .revvit import RevVisionTransformer +from .riformer import RIFormer +from .seresnet import SEResNet +from .seresnext import SEResNeXt +from .shufflenet_v1 import ShuffleNetV1 +from .shufflenet_v2 import ShuffleNetV2 +from .sparse_convnext import SparseConvNeXt +from .sparse_resnet import SparseResNet +from .swin_transformer import SwinTransformer +from .swin_transformer_v2 import SwinTransformerV2 +from .t2t_vit import T2T_ViT +from .timm_backbone import TIMMBackbone +from .tinyvit import TinyViT +from .tnt import TNT +from .twins import PCPVT, SVT +from .van import VAN +from .vgg import VGG +from .vig import PyramidVig, Vig +from .vision_transformer import VisionTransformer +from .vit_eva02 import ViTEVA02 +from .vit_sam import ViTSAM +from .xcit import XCiT + +__all__ = [ + 'LeNet5', + 'AlexNet', + 'VGG', + 'RegNet', + 'ResNet', + 'ResNeXt', + 'ResNetV1d', + 'ResNeSt', + 'ResNet_CIFAR', + 'SEResNet', + 'SEResNeXt', + 'ShuffleNetV1', + 'ShuffleNetV2', + 'MobileNetV2', + 'MobileNetV3', + 'VisionTransformer', + 'SwinTransformer', + 'TNT', + 'TIMMBackbone', + 'T2T_ViT', + 'Res2Net', + 'RepVGG', + 'Conformer', + 'MlpMixer', + 'DistilledVisionTransformer', + 'PCPVT', + 'SVT', + 'EfficientNet', + 'EfficientNetV2', + 'ConvNeXt', + 'HRNet', + 'ResNetV1c', + 'ConvMixer', + 'EdgeNeXt', + 'CSPDarkNet', + 'CSPResNet', + 'CSPResNeXt', + 'CSPNet', + 'RepLKNet', + 'RepMLPNet', + 'PoolFormer', + 'RIFormer', + 'DenseNet', + 'VAN', + 'InceptionV3', + 'MobileOne', + 'EfficientFormer', + 'SwinTransformerV2', + 'MViT', + 'DeiT3', + 'HorNet', + 'MobileViT', + 'DaViT', + 'BEiTViT', + 'RevVisionTransformer', + 'MixMIMTransformer', + 'TinyViT', + 'LeViT', + 'Vig', + 'PyramidVig', + 'XCiT', + 'ViTSAM', + 'ViTEVA02', + 'HiViT', + 'SparseResNet', + 'SparseConvNeXt', +] diff --git a/mmpretrain/models/backbones/alexnet.py b/mmpretrain/models/backbones/alexnet.py new file mode 100644 index 0000000..f7c2891 --- /dev/null +++ b/mmpretrain/models/backbones/alexnet.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class AlexNet(BaseBackbone): + """`AlexNet `_ backbone. + + The input for AlexNet is a 224x224 RGB image. 
+ + Args: + num_classes (int): number of classes for classification. + The default value is -1, which uses the backbone as + a feature extractor without the top classifier. + """ + + def __init__(self, num_classes=-1): + super(AlexNet, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + + return (x, ) diff --git a/mmpretrain/models/backbones/base_backbone.py b/mmpretrain/models/backbones/base_backbone.py new file mode 100644 index 0000000..751aa95 --- /dev/null +++ b/mmpretrain/models/backbones/base_backbone.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +from mmengine.model import BaseModule + + +class BaseBackbone(BaseModule, metaclass=ABCMeta): + """Base backbone. + + This class defines the basic functions of a backbone. Any backbone that + inherits this class should at least define its own `forward` function. + """ + + def __init__(self, init_cfg=None): + super(BaseBackbone, self).__init__(init_cfg) + + @abstractmethod + def forward(self, x): + """Forward computation. + + Args: + x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of + Torch.tensor, containing input data for forward computation. + """ + pass + + def train(self, mode=True): + """Set module status before forward computation. + + Args: + mode (bool): Whether it is train_mode or test_mode + """ + super(BaseBackbone, self).train(mode) diff --git a/mmpretrain/models/backbones/beit.py b/mmpretrain/models/backbones/beit.py new file mode 100644 index 0000000..3c7d908 --- /dev/null +++ b/mmpretrain/models/backbones/beit.py @@ -0,0 +1,697 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed, + resize_relative_position_bias_table, to_2tuple) +from .base_backbone import BaseBackbone +from .vision_transformer import TransformerEncoderLayer + + +class RelativePositionBias(BaseModule): + """Relative Position Bias. + + This module is copied from + https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py#L209. + + Args: + window_size (Sequence[int]): The window size of the relative + position bias. + num_heads (int): The number of head in multi-head attention. 
+ with_cls_token (bool): To indicate the backbone has cls_token or not. + Defaults to True. + """ + + def __init__( + self, + window_size: Sequence[int], + num_heads: int, + with_cls_token: bool = True, + ) -> None: + super().__init__() + self.window_size = window_size + if with_cls_token: + num_extra_tokens = 3 + else: + num_extra_tokens = 0 + # cls to token & token to cls & cls to cls + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1) + num_extra_tokens + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each + # token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] -\ + coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + if with_cls_token: + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1, ) * 2, + dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum( + -1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + else: + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1], ) * 2, + dtype=relative_coords.dtype) + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + + self.register_buffer('relative_position_index', + relative_position_index) + + def forward(self) -> torch.Tensor: + # Wh*Ww,Wh*Ww,nH + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) + return relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class BEiTTransformerEncoderLayer(TransformerEncoderLayer): + """Implements one encoder layer in BEiT. + + Comparing with conventional ``TransformerEncoderLayer``, this module + adds weights to the shortcut connection. In addition, ``BEiTAttention`` + is used to replace the original ``MultiheadAttention`` in + ``TransformerEncoderLayer``. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. 1 means no scaling. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + window_size (tuple[int]): The height and width of the window. + Defaults to None. + use_rel_pos_bias (bool): Whether to use unique relative position bias, + if False, use shared relative position bias defined in backbone. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. 
+ bias (bool | str): The option to add leanable bias for q, k, v. If bias + is True, it will add leanable bias. If bias is 'qv_bias', it will + only add leanable bias for q, v. If bias is False, it will not add + bias for q, k, v. Default to 'qv_bias'. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='LN'). + attn_cfg (dict): The configuration for the attention layer. + Defaults to an empty dict. + ffn_cfg (dict): The configuration for the ffn layer. + Defaults to ``dict(add_identity=False)``. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + layer_scale_init_value: float, + window_size: Tuple[int, int], + use_rel_pos_bias: bool, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + bias: Union[str, bool] = 'qv_bias', + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + attn_cfg: dict = dict(), + ffn_cfg: dict = dict(add_identity=False), + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + attn_drop_rate=attn_drop_rate, + drop_path_rate=0., + drop_rate=0., + num_fcs=num_fcs, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + attn_cfg = { + 'window_size': window_size, + 'use_rel_pos_bias': use_rel_pos_bias, + 'qk_scale': None, + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'attn_drop': attn_drop_rate, + 'proj_drop': drop_rate, + 'bias': bias, + **attn_cfg, + } + self.attn = BEiTAttention(**attn_cfg) + + ffn_cfg = { + 'embed_dims': embed_dims, + 'feedforward_channels': feedforward_channels, + 'num_fcs': num_fcs, + 'ffn_drop': drop_rate, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path_rate), + 'act_cfg': act_cfg, + **ffn_cfg, + } + self.ffn = FFN(**ffn_cfg) + + # NOTE: drop path for stochastic depth, we shall see if + # this is better than dropout here + dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate) + self.drop_path = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + + if layer_scale_init_value > 0: + self.gamma_1 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + self.gamma_2 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x: torch.Tensor, + rel_pos_bias: torch.Tensor) -> torch.Tensor: + if self.gamma_1 is None: + x = x + self.drop_path( + self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.ffn(self.ln2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn( + self.ln1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x))) + return x + + +@MODELS.register_module() +class BEiTViT(BaseBackbone): + """Backbone for BEiT. + + A PyTorch implement of : `BEiT: BERT Pre-Training of Image Transformers + `_ + A PyTorch implement of : `BEiT v2: Masked Image Modeling with + Vector-Quantized Visual Tokenizers `_ + + Args: + arch (str | dict): BEiT architecture. If use string, choose from + 'base', 'large'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. 
+            - **num_layers** (int): The number of transformer encoder layers.
+            - **num_heads** (int): The number of heads in attention modules.
+            - **feedforward_channels** (int): The hidden dimensions in
+              feedforward modules.
+
+            Defaults to 'base'.
+        img_size (int | tuple): The expected input image shape. Because we
+            support dynamic input shape, just set the argument to the most
+            common input image shape. Defaults to 224.
+        patch_size (int | tuple): The patch size in patch embedding.
+            Defaults to 16.
+        in_channels (int): The number of input channels. Defaults to 3.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, which means the last stage.
+        drop_rate (float): Probability of an element to be zeroed.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        bias (bool | str): The option to add learnable bias for q, k, v. If
+            bias is True, it will add learnable bias. If bias is 'qv_bias',
+            it will only add learnable bias for q, v. If bias is False, it
+            will not add bias for q, k, v. Defaults to 'qv_bias'.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize the
+            final feature map. Defaults to False.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor, including patch tokens and
+              class tokens, with shape (B, L, C).
+
+            Defaults to ``"avg_featmap"``.
+        with_cls_token (bool): Whether to concatenate the class token into the
+            image tokens as transformer input. Defaults to True.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        use_abs_pos_emb (bool): Use position embedding like vanilla ViT.
+            Defaults to False.
+        use_rel_pos_bias (bool): Use relative position embedding in each
+            transformer encoder layer. Defaults to True.
+        use_shared_rel_pos_bias (bool): Use a shared relative position
+            embedding; all transformer encoder layers share the same relative
+            position embedding. Defaults to False.
+        layer_scale_init_value (float): The initialization value for
+            the learnable scaling of attention and FFN. Defaults to 0.1.
+        interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
+            dict.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            encoder. Defaults to an empty dict.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
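+
+    Example:
+        A minimal, illustrative usage sketch (not part of the original
+        docstring). The printed shape assumes the default 'base' architecture
+        with ``out_type='avg_featmap'`` and a 224x224 input:
+
+        >>> import torch
+        >>> from mmpretrain.models import BEiTViT
+        >>> model = BEiTViT(arch='base')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outs = model(inputs)
+        >>> print(outs[-1].shape)
+        torch.Size([1, 768])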
+ """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 768, + 'num_layers': 8, + 'num_heads': 8, + 'feedforward_channels': 768 * 3, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['eva-g', 'eva-giant'], + { + # The implementation in EVA + # + 'embed_dims': 1408, + 'num_layers': 40, + 'num_heads': 16, + 'feedforward_channels': 6144 + }), + **dict.fromkeys( + ['deit-t', 'deit-tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': 192 * 4 + }), + **dict.fromkeys( + ['deit-s', 'deit-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 384 * 4 + }), + **dict.fromkeys( + ['deit-b', 'deit-base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 768 * 4 + }), + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0, + drop_path_rate=0, + bias='qv_bias', + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=False, + out_type='avg_featmap', + with_cls_token=True, + frozen_stages=-1, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + interpolate_mode='bicubic', + layer_scale_init_value=0.1, + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super(BEiTViT, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + if use_abs_pos_emb: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + else: + self.pos_embed = None + self.drop_after_pos = 
nn.Dropout(p=drop_rate) + + assert not (use_rel_pos_bias and use_shared_rel_pos_bias), ( + '`use_rel_pos_bias` and `use_shared_rel_pos_bias` cannot be set ' + 'to True at the same time') + self.use_rel_pos_bias = use_rel_pos_bias + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias( + window_size=self.patch_resolution, + num_heads=self.arch_settings['num_heads']) + else: + self.rel_pos_bias = None + self._register_load_state_dict_pre_hook( + self._prepare_relative_position_bias_table) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + layer_scale_init_value=layer_scale_init_value, + window_size=self.patch_resolution, + use_rel_pos_bias=use_rel_pos_bias, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + bias=bias, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg)) + + self.frozen_stages = frozen_stages + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + if out_type == 'avg_featmap': + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def init_weights(self): + super(BEiTViT, self).init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if (not self.with_cls_token + and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1): + # Remove cls token from state dict if it's not used. 
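+            # (This happens when a checkpoint trained with a cls token is
+            # loaded into a model built with ``with_cls_token=False``.)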
+ state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + if self.with_cls_token: + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + if self.out_type == 'avg_featmap': + self.ln2.eval() + for param in self.ln2.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + if self.pos_embed is not None: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + rel_pos_bias = self.rel_pos_bias() \ + if self.rel_pos_bias is not None else None + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, rel_pos_bias) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(patch_token.mean(dim=1)) + + def _prepare_relative_position_bias_table(self, state_dict, prefix, *args, + **kwargs): + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + + if self.use_rel_pos_bias and 'rel_pos_bias.relative_position_bias_table' in state_dict: # noqa:E501 + logger.info('Expand the shared relative position embedding to ' + 'each transformer block.') + rel_pos_bias = state_dict[ + 'rel_pos_bias.relative_position_bias_table'] + for i in range(self.num_layers): + state_dict[ + f'layers.{i}.attn.relative_position_bias_table'] = \ + rel_pos_bias.clone() + state_dict.pop('rel_pos_bias.relative_position_bias_table') + 
            state_dict.pop('rel_pos_bias.relative_position_index')
+
+        state_dict_model = self.state_dict()
+        all_keys = list(state_dict_model.keys())
+        for key in all_keys:
+            if 'relative_position_bias_table' in key:
+                ckpt_key = prefix + key
+                if ckpt_key not in state_dict:
+                    continue
+                rel_pos_bias_pretrained = state_dict[ckpt_key]
+                rel_pos_bias_current = state_dict_model[key]
+                L1, nH1 = rel_pos_bias_pretrained.size()
+                L2, nH2 = rel_pos_bias_current.size()
+                src_size = int((L1 - 3)**0.5)
+                dst_size = int((L2 - 3)**0.5)
+                if L1 != L2:
+                    extra_tokens = rel_pos_bias_pretrained[-3:, :]
+                    rel_pos_bias = rel_pos_bias_pretrained[:-3, :]
+
+                    new_rel_pos_bias = resize_relative_position_bias_table(
+                        src_size, dst_size, rel_pos_bias, nH1)
+                    new_rel_pos_bias = torch.cat(
+                        (new_rel_pos_bias, extra_tokens), dim=0)
+                    logger.info('Resize the relative_position_bias_table from '
+                                f'{state_dict[ckpt_key].shape} to '
+                                f'{new_rel_pos_bias.shape}')
+                    state_dict[ckpt_key] = new_rel_pos_bias
+
+                    # The index buffer needs to be re-generated.
+                    index_buffer = ckpt_key.replace('bias_table', 'index')
+                    if index_buffer in state_dict:
+                        del state_dict[index_buffer]
+
+    def get_layer_depth(self, param_name: str, prefix: str = ''):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the num of layers.
+
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``).
+        """
+        num_layers = self.num_layers + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent modules like the head
+            return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix):]
+
+        if param_name in ('cls_token', 'pos_embed'):
+            layer_depth = 0
+        elif param_name.startswith('patch_embed'):
+            layer_depth = 0
+        elif param_name.startswith('layers'):
+            layer_id = int(param_name.split('.')[1])
+            layer_depth = layer_id + 1
+        else:
+            layer_depth = num_layers - 1
+
+        return layer_depth, num_layers
diff --git a/mmpretrain/models/backbones/conformer.py b/mmpretrain/models/backbones/conformer.py
new file mode 100644
index 0000000..eda72b0
--- /dev/null
+++ b/mmpretrain/models/backbones/conformer.py
@@ -0,0 +1,621 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.drop import DropPath
+from mmcv.cnn.bricks.transformer import AdaptivePadding
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import trunc_normal_
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+from .vision_transformer import TransformerEncoderLayer
+
+
+class ConvBlock(BaseModule):
+    """Basic convolution block used in Conformer.
+
+    This block includes three convolution modules, and supports three new
+    functions:
+    1. Returns the output of both the final layers and the second convolution
+    module.
+    2. Fuses the input of the second convolution module with an extra input
+    feature map.
+    3. Supports adding an extra convolution module to the identity connection.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        stride (int): The stride of the second convolution module.
+            Defaults to 1.
+        groups (int): The groups of the second convolution module.
+            Defaults to 1.
+        drop_path_rate (float): The rate of the DropPath layer. Defaults to 0.
+        with_residual_conv (bool): Whether to add an extra convolution module
+            to the identity connection. Defaults to False.
+        norm_cfg (dict): The config of normalization layers.
+            Defaults to ``dict(type='BN', eps=1e-6)``.
+        act_cfg (dict): The config of activation functions.
+            Defaults to ``dict(type='ReLU', inplace=True)``.
+        init_cfg (dict, optional): The extra config to initialize the module.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride=1,
+                 groups=1,
+                 drop_path_rate=0.,
+                 with_residual_conv=False,
+                 norm_cfg=dict(type='BN', eps=1e-6),
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 init_cfg=None):
+        super(ConvBlock, self).__init__(init_cfg=init_cfg)
+
+        expansion = 4
+        mid_channels = out_channels // expansion
+
+        self.conv1 = nn.Conv2d(
+            in_channels,
+            mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False)
+        self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
+        self.act1 = build_activation_layer(act_cfg)
+
+        self.conv2 = nn.Conv2d(
+            mid_channels,
+            mid_channels,
+            kernel_size=3,
+            stride=stride,
+            groups=groups,
+            padding=1,
+            bias=False)
+        self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
+        self.act2 = build_activation_layer(act_cfg)
+
+        self.conv3 = nn.Conv2d(
+            mid_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False)
+        self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
+        self.act3 = build_activation_layer(act_cfg)
+
+        if with_residual_conv:
+            self.residual_conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=stride,
+                padding=0,
+                bias=False)
+            self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]
+
+        self.with_residual_conv = with_residual_conv
+        self.drop_path = DropPath(
+            drop_path_rate) if drop_path_rate > 0.
else nn.Identity() + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x, fusion_features=None, out_conv2=True): + identity = x + + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) if fusion_features is None else self.conv2( + x + fusion_features) + x = self.bn2(x) + x2 = self.act2(x) + + x = self.conv3(x2) + x = self.bn3(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.with_residual_conv: + identity = self.residual_conv(identity) + identity = self.residual_bn(identity) + + x += identity + x = self.act3(x) + + if out_conv2: + return x, x2 + else: + return x + + +class FCUDown(BaseModule): + """CNN feature maps -> Transformer patch embeddings.""" + + def __init__(self, + in_channels, + out_channels, + down_stride, + with_cls_token=True, + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(FCUDown, self).__init__(init_cfg=init_cfg) + self.down_stride = down_stride + self.with_cls_token = with_cls_token + + self.conv_project = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.sample_pooling = nn.AvgPool2d( + kernel_size=down_stride, stride=down_stride) + + self.ln = build_norm_layer(norm_cfg, out_channels)[1] + self.act = build_activation_layer(act_cfg) + + def forward(self, x, x_t): + x = self.conv_project(x) # [N, C, H, W] + + x = self.sample_pooling(x).flatten(2).transpose(1, 2) + x = self.ln(x) + x = self.act(x) + + if self.with_cls_token: + x = torch.cat([x_t[:, 0][:, None, :], x], dim=1) + + return x + + +class FCUUp(BaseModule): + """Transformer patch embeddings -> CNN feature maps.""" + + def __init__(self, + in_channels, + out_channels, + up_stride, + with_cls_token=True, + norm_cfg=dict(type='BN', eps=1e-6), + act_cfg=dict(type='ReLU', inplace=True), + init_cfg=None): + super(FCUUp, self).__init__(init_cfg=init_cfg) + + self.up_stride = up_stride + self.with_cls_token = with_cls_token + + self.conv_project = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.bn = build_norm_layer(norm_cfg, out_channels)[1] + self.act = build_activation_layer(act_cfg) + + def forward(self, x, H, W): + B, _, C = x.shape + # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14] + if self.with_cls_token: + x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W) + else: + x_r = x.transpose(1, 2).reshape(B, C, H, W) + + x_r = self.act(self.bn(self.conv_project(x_r))) + + return F.interpolate( + x_r, size=(H * self.up_stride, W * self.up_stride)) + + +class ConvTransBlock(BaseModule): + """Basic module for Conformer. + + This module is a fusion of CNN block transformer encoder block. + + Args: + in_channels (int): The number of input channels in conv blocks. + out_channels (int): The number of output channels in conv blocks. + embed_dims (int): The embedding dimension in transformer blocks. + conv_stride (int): The stride of conv2d layers. Defaults to 1. + groups (int): The groups of conv blocks. Defaults to 1. + with_residual_conv (bool): Whether to add a conv-bn layer to the + identity connect in the conv block. Defaults to False. + down_stride (int): The stride of the downsample pooling layer. + Defaults to 4. + num_heads (int): The number of heads in transformer attention layers. + Defaults to 12. + mlp_ratio (float): The expansion ratio in transformer FFN module. + Defaults to 4. + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + with_cls_token (bool): Whether use class token or not. 
+ Defaults to True. + drop_rate (float): The dropout rate of the output projection and + FFN in the transformer block. Defaults to 0. + attn_drop_rate (float): The dropout rate after the attention + calculation in the transformer block. Defaults to 0. + drop_path_rate (bloat): The drop path rate in both the conv block + and the transformer block. Defaults to 0. + last_fusion (bool): Whether this block is the last stage. If so, + downsample the fusion feature map. + init_cfg (dict, optional): The extra config to initialize the module. + Defaults to None. + """ + + def __init__(self, + in_channels, + out_channels, + embed_dims, + conv_stride=1, + groups=1, + with_residual_conv=False, + down_stride=4, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + with_cls_token=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + last_fusion=False, + init_cfg=None): + super(ConvTransBlock, self).__init__(init_cfg=init_cfg) + expansion = 4 + self.cnn_block = ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + with_residual_conv=with_residual_conv, + stride=conv_stride, + groups=groups) + + if last_fusion: + self.fusion_block = ConvBlock( + in_channels=out_channels, + out_channels=out_channels, + stride=2, + with_residual_conv=True, + groups=groups, + drop_path_rate=drop_path_rate) + else: + self.fusion_block = ConvBlock( + in_channels=out_channels, + out_channels=out_channels, + groups=groups, + drop_path_rate=drop_path_rate) + + self.squeeze_block = FCUDown( + in_channels=out_channels // expansion, + out_channels=embed_dims, + down_stride=down_stride, + with_cls_token=with_cls_token) + + self.expand_block = FCUUp( + in_channels=embed_dims, + out_channels=out_channels // expansion, + up_stride=down_stride, + with_cls_token=with_cls_token) + + self.trans_block = TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=int(embed_dims * mlp_ratio), + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + attn_drop_rate=attn_drop_rate, + qkv_bias=qkv_bias, + norm_cfg=dict(type='LN', eps=1e-6)) + + self.down_stride = down_stride + self.embed_dim = embed_dims + self.last_fusion = last_fusion + + def forward(self, cnn_input, trans_input): + x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True) + + _, _, H, W = x_conv2.shape + + # Convert the feature map of conv2 to transformer embedding + # and concat with class token. + conv2_embedding = self.squeeze_block(x_conv2, trans_input) + + trans_output = self.trans_block(conv2_embedding + trans_input) + + # Convert the transformer output embedding to feature map + trans_features = self.expand_block(trans_output, H // self.down_stride, + W // self.down_stride) + x = self.fusion_block( + x, fusion_features=trans_features, out_conv2=False) + + return x, trans_output + + +@MODELS.register_module() +class Conformer(BaseBackbone): + """Conformer backbone. + + A PyTorch implementation of : `Conformer: Local Features Coupling Global + Representations for Visual Recognition `_ + + Args: + arch (str | dict): Conformer architecture. Defaults to 'tiny'. + patch_size (int): The patch size. Defaults to 16. + base_channels (int): The base number of channels in CNN network. + Defaults to 64. + mlp_ratio (float): The expansion ratio of FFN network in transformer + block. Defaults to 4. + with_cls_token (bool): Whether use class token or not. + Defaults to True. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which stages. 
+ Defaults to -1, means the last stage. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 384, + 'channel_ratio': 1, + 'num_heads': 6, + 'depths': 12 + }), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 384, + 'channel_ratio': 4, + 'num_heads': 6, + 'depths': 12 + }), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 576, + 'channel_ratio': 6, + 'num_heads': 9, + 'depths': 12 + }), + } # yapf: disable + + _version = 1 + + def __init__(self, + arch='tiny', + patch_size=16, + base_channels=64, + mlp_ratio=4., + qkv_bias=True, + with_cls_token=True, + drop_path_rate=0., + norm_eval=True, + frozen_stages=0, + out_indices=-1, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'channel_ratio' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.num_features = self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.channel_ratio = self.arch_settings['channel_ratio'] + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.depths + index + 1 + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + self.with_cls_token = with_cls_token + if self.with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + # stochastic depth decay rule + self.trans_dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, self.depths) + ] + + # Stem stage: get the feature maps by conv block + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, + bias=False) # 1 / 2 [112, 112] + self.bn1 = nn.BatchNorm2d(64) + self.act1 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56] + + assert patch_size % 16 == 0, 'The patch size of Conformer must ' \ + 'be divisible by 16.' 
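+        # The stem (7x7 conv with stride 2 followed by a stride-2 max pooling)
+        # already downsamples by 4, so the transformer branch only needs an
+        # extra stride of patch_size // 4 to reach the full patch_size stride.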
+ trans_down_stride = patch_size // 4 + + # To solve the issue #680 + # Auto pad the feature map to be divisible by trans_down_stride + self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride) + + # 1 stage + stage1_channels = int(base_channels * self.channel_ratio) + self.conv_1 = ConvBlock( + in_channels=64, + out_channels=stage1_channels, + with_residual_conv=True, + stride=1) + self.trans_patch_conv = nn.Conv2d( + 64, + self.embed_dims, + kernel_size=trans_down_stride, + stride=trans_down_stride, + padding=0) + + self.trans_1 = TransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=int(self.embed_dims * mlp_ratio), + drop_path_rate=self.trans_dpr[0], + qkv_bias=qkv_bias, + norm_cfg=dict(type='LN', eps=1e-6)) + + # 2~4 stage + init_stage = 2 + fin_stage = self.depths // 3 + 1 + for i in range(init_stage, fin_stage): + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=stage1_channels, + out_channels=stage1_channels, + embed_dims=self.embed_dims, + conv_stride=1, + with_residual_conv=False, + down_stride=trans_down_stride, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token)) + + stage2_channels = int(base_channels * self.channel_ratio * 2) + # 5~8 stage + init_stage = fin_stage # 5 + fin_stage = fin_stage + self.depths // 3 # 9 + for i in range(init_stage, fin_stage): + if i == init_stage: + conv_stride = 2 + in_channels = stage1_channels + else: + conv_stride = 1 + in_channels = stage2_channels + + with_residual_conv = True if i == init_stage else False + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=in_channels, + out_channels=stage2_channels, + embed_dims=self.embed_dims, + conv_stride=conv_stride, + with_residual_conv=with_residual_conv, + down_stride=trans_down_stride // 2, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token)) + + stage3_channels = int(base_channels * self.channel_ratio * 2 * 2) + # 9~12 stage + init_stage = fin_stage # 9 + fin_stage = fin_stage + self.depths // 3 # 13 + for i in range(init_stage, fin_stage): + if i == init_stage: + conv_stride = 2 + in_channels = stage2_channels + with_residual_conv = True + else: + conv_stride = 1 + in_channels = stage3_channels + with_residual_conv = False + + last_fusion = (i == self.depths) + + self.add_module( + f'conv_trans_{i}', + ConvTransBlock( + in_channels=in_channels, + out_channels=stage3_channels, + embed_dims=self.embed_dims, + conv_stride=conv_stride, + with_residual_conv=with_residual_conv, + down_stride=trans_down_stride // 4, + num_heads=self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rate=self.trans_dpr[i - 1], + with_cls_token=self.with_cls_token, + last_fusion=last_fusion)) + self.fin_stage = fin_stage + + self.pooling = nn.AdaptiveAvgPool2d(1) + self.trans_norm = nn.LayerNorm(self.embed_dims) + + if self.with_cls_token: + trunc_normal_(self.cls_token, std=.02) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + 
nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def init_weights(self): + super(Conformer, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + self.apply(self._init_weights) + + def forward(self, x): + output = [] + B = x.shape[0] + if self.with_cls_token: + cls_tokens = self.cls_token.expand(B, -1, -1) + + # stem + x_base = self.maxpool(self.act1(self.bn1(self.conv1(x)))) + x_base = self.auto_pad(x_base) + + # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56] + x = self.conv_1(x_base, out_conv2=False) + x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2) + if self.with_cls_token: + x_t = torch.cat([cls_tokens, x_t], dim=1) + x_t = self.trans_1(x_t) + + # 2 ~ final + for i in range(2, self.fin_stage): + stage = getattr(self, f'conv_trans_{i}') + x, x_t = stage(x, x_t) + if i in self.out_indices: + if self.with_cls_token: + output.append([ + self.pooling(x).flatten(1), + self.trans_norm(x_t)[:, 0] + ]) + else: + # if no class token, use the mean patch token + # as the transformer feature. + output.append([ + self.pooling(x).flatten(1), + self.trans_norm(x_t).mean(dim=1) + ]) + + return tuple(output) diff --git a/mmpretrain/models/backbones/convmixer.py b/mmpretrain/models/backbones/convmixer.py new file mode 100644 index 0000000..480050d --- /dev/null +++ b/mmpretrain/models/backbones/convmixer.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer, + build_norm_layer) +from mmengine.utils import digit_version + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class Residual(nn.Module): + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +@MODELS.register_module() +class ConvMixer(BaseBackbone): + """ConvMixer. . + + A PyTorch implementation of : `Patches Are All You Need? + `_ + + Modified from the `official repo + `_ + and `timm + `_. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvMixer.arch_settings``. And if dict, it + should include the following two keys: + + - embed_dims (int): The dimensions of patch embedding. + - depth (int): Number of repetitions of ConvMixer Layer. + - patch_size (int): The patch size. + - kernel_size (int): The kernel size of depthwise conv layers. + + Defaults to '768/32'. + in_channels (int): Number of input image channels. Defaults to 3. + patch_size (int): The size of one patch in the patch embed layer. + Defaults to 7. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation after each convolution. + Defaults to ``dict(type='GELU')``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict. 
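+
+    Example:
+        An illustrative usage sketch (not part of the original docstring).
+        With the default '768/32' architecture and a 224x224 input, the stem
+        downsamples by the patch size (7), so the output map is 32x32:
+
+        >>> import torch
+        >>> from mmpretrain.models import ConvMixer
+        >>> model = ConvMixer(arch='768/32')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outs = model(inputs)
+        >>> print(outs[-1].shape)
+        torch.Size([1, 768, 32, 32])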
+ """ + arch_settings = { + '768/32': { + 'embed_dims': 768, + 'depth': 32, + 'patch_size': 7, + 'kernel_size': 7 + }, + '1024/20': { + 'embed_dims': 1024, + 'depth': 20, + 'patch_size': 14, + 'kernel_size': 9 + }, + '1536/20': { + 'embed_dims': 1536, + 'depth': 20, + 'patch_size': 7, + 'kernel_size': 9 + }, + } + + def __init__(self, + arch='768/32', + in_channels=3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + out_indices=-1, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = { + 'embed_dims', 'depth', 'patch_size', 'kernel_size' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + + self.embed_dims = arch['embed_dims'] + self.depth = arch['depth'] + self.patch_size = arch['patch_size'] + self.kernel_size = arch['kernel_size'] + self.act = build_activation_layer(act_cfg) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.depth + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # Set stem layers + self.stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.embed_dims, + kernel_size=self.patch_size, + stride=self.patch_size), self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1]) + + # Set conv2d according to torch version + convfunc = nn.Conv2d + if digit_version(torch.__version__) < digit_version('1.9.0'): + convfunc = Conv2dAdaptivePadding + + # Repetitions of ConvMixer Layer + self.stages = nn.Sequential(*[ + nn.Sequential( + Residual( + nn.Sequential( + convfunc( + self.embed_dims, + self.embed_dims, + self.kernel_size, + groups=self.embed_dims, + padding='same'), self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1])), + nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1), + self.act, + build_norm_layer(norm_cfg, self.embed_dims)[1]) + for _ in range(self.depth) + ]) + + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.out_indices: + outs.append(x) + + # x = self.pooling(x).flatten(1) + return tuple(outs) + + def train(self, mode=True): + super(ConvMixer, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False diff --git a/mmpretrain/models/backbones/convnext.py b/mmpretrain/models/backbones/convnext.py new file mode 100644 index 0000000..6a954f5 --- /dev/null +++ b/mmpretrain/models/backbones/convnext.py @@ -0,0 +1,412 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from functools import partial +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import GRN, build_norm_layer +from .base_backbone import BaseBackbone + + +class ConvNeXtBlock(BaseModule): + """ConvNeXt Block. + + Args: + in_channels (int): The number of input channels. + dw_conv_cfg (dict): Config of depthwise convolution. + Defaults to ``dict(kernel_size=7, padding=3)``. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + mlp_ratio (float): The expansion ratio in both pointwise convolution. + Defaults to 4. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. More details can be found in the note. + Defaults to True. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + + Note: + There are two equivalent implementations: + + 1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv; + all outputs are in (N, C, H, W). + 2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU + -> Linear; Permute back + + As default, we use the second to align with the official repository. + And it may be slightly faster. + """ + + def __init__(self, + in_channels, + dw_conv_cfg=dict(kernel_size=7, padding=3), + norm_cfg=dict(type='LN2d', eps=1e-6), + act_cfg=dict(type='GELU'), + mlp_ratio=4., + linear_pw_conv=True, + drop_path_rate=0., + layer_scale_init_value=1e-6, + use_grn=False, + with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.depthwise_conv = nn.Conv2d( + in_channels, in_channels, groups=in_channels, **dw_conv_cfg) + + self.linear_pw_conv = linear_pw_conv + self.norm = build_norm_layer(norm_cfg, in_channels) + + mid_channels = int(mlp_ratio * in_channels) + if self.linear_pw_conv: + # Use linear layer to do pointwise conv. + pw_conv = nn.Linear + else: + pw_conv = partial(nn.Conv2d, kernel_size=1) + + self.pointwise_conv1 = pw_conv(in_channels, mid_channels) + self.act = MODELS.build(act_cfg) + self.pointwise_conv2 = pw_conv(mid_channels, in_channels) + + if use_grn: + self.grn = GRN(mid_channels) + else: + self.grn = None + + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((in_channels)), + requires_grad=True) if layer_scale_init_value > 0 else None + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + shortcut = x + x = self.depthwise_conv(x) + + if self.linear_pw_conv: + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x, data_format='channel_last') + x = self.pointwise_conv1(x) + x = self.act(x) + if self.grn is not None: + x = self.grn(x, data_format='channel_last') + x = self.pointwise_conv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + else: + x = self.norm(x, data_format='channel_first') + x = self.pointwise_conv1(x) + x = self.act(x) + + if self.grn is not None: + x = self.grn(x, data_format='channel_first') + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = x.mul(self.gamma.view(1, -1, 1, 1)) + + x = shortcut + self.drop_path(x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class ConvNeXt(BaseBackbone): + """ConvNeXt v1&v2 backbone. + + A PyTorch implementation of `A ConvNet for the 2020s + `_ and + `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders + `_ + + Modified from the `official repo + `_ + and `timm + `_. + + To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvNeXt.arch_settings``. And if dict, it + should include the following two keys: + + - depths (list[int]): Number of blocks at each stage. + - channels (list[int]): The number of channels at each stage. + + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + stem_patch_size (int): The size of one patch in the stem layer. + Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to True. + use_grn (bool): Whether to add Global Response Normalization in the + blocks. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. 
+ init_cfg (dict, optional): Initialization config dict + """ # noqa: E501 + arch_settings = { + 'atto': { + 'depths': [2, 2, 6, 2], + 'channels': [40, 80, 160, 320] + }, + 'femto': { + 'depths': [2, 2, 6, 2], + 'channels': [48, 96, 192, 384] + }, + 'pico': { + 'depths': [2, 2, 6, 2], + 'channels': [64, 128, 256, 512] + }, + 'nano': { + 'depths': [2, 2, 8, 2], + 'channels': [80, 160, 320, 640] + }, + 'tiny': { + 'depths': [3, 3, 9, 3], + 'channels': [96, 192, 384, 768] + }, + 'small': { + 'depths': [3, 3, 27, 3], + 'channels': [96, 192, 384, 768] + }, + 'base': { + 'depths': [3, 3, 27, 3], + 'channels': [128, 256, 512, 1024] + }, + 'large': { + 'depths': [3, 3, 27, 3], + 'channels': [192, 384, 768, 1536] + }, + 'xlarge': { + 'depths': [3, 3, 27, 3], + 'channels': [256, 512, 1024, 2048] + }, + 'huge': { + 'depths': [3, 3, 27, 3], + 'channels': [352, 704, 1408, 2816] + } + } + + def __init__(self, + arch='tiny', + in_channels=3, + stem_patch_size=4, + norm_cfg=dict(type='LN2d', eps=1e-6), + act_cfg=dict(type='GELU'), + linear_pw_conv=True, + use_grn=False, + drop_path_rate=0., + layer_scale_init_value=1e-6, + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + with_cp=False, + init_cfg=[ + dict( + type='TruncNormal', + layer=['Conv2d', 'Linear'], + std=.02, + bias=0.), + dict( + type='Constant', layer=['LayerNorm'], val=1., + bias=0.), + ]): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'depths' in arch and 'channels' in arch, \ + f'The arch dict must have "depths" and "channels", ' \ + f'but got {list(arch.keys())}.' + + self.depths = arch['depths'] + self.channels = arch['channels'] + assert (isinstance(self.depths, Sequence) + and isinstance(self.channels, Sequence) + and len(self.depths) == len(self.channels)), \ + f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \ + 'should be both sequence with the same length.' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + block_idx = 0 + + # 4 downsample layers between stages, including the stem layer. 
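+        # In forward(), downsample_layers[i] is applied right before stages[i],
+        # so the stem acts as the first entry of this list.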
+ self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.channels[0], + kernel_size=stem_patch_size, + stride=stem_patch_size), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + # 4 feature resolution stages, each consisting of multiple residual + # blocks + self.stages = nn.ModuleList() + + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2), + ) + self.downsample_layers.append(downsample_layer) + + stage = Sequential(*[ + ConvNeXtBlock( + in_channels=channels, + drop_path_rate=dpr[block_idx + j], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + layer_scale_init_value=layer_scale_init_value, + use_grn=use_grn, + with_cp=with_cp) for j in range(depth) + ]) + block_idx += depth + + self.stages.append(stage) + + if i in self.out_indices: + norm_layer = build_norm_layer(norm_cfg, channels) + self.add_module(f'norm{i}', norm_layer) + + self._freeze_stages() + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap).flatten(1)) + else: + outs.append(norm_layer(x)) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.downsample_layers[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(ConvNeXt, self).train(mode) + self._freeze_stages() + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + """ + + max_layer_id = 12 if self.depths[-2] > 9 else 6 + + if not param_name.startswith(prefix): + # For subsequent module like head + return max_layer_id + 1, max_layer_id + 2 + + param_name = param_name[len(prefix):] + if param_name.startswith('downsample_layers'): + stage_id = int(param_name.split('.')[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + else: # stage_id == 3: + layer_id = max_layer_id + + elif param_name.startswith('stages'): + stage_id = int(param_name.split('.')[1]) + block_id = int(param_name.split('.')[2]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + else: # stage_id == 3: + layer_id = max_layer_id + + # final norm layer + else: + layer_id = max_layer_id + 1 + + return layer_id, max_layer_id + 2 diff --git a/mmpretrain/models/backbones/cspnet.py b/mmpretrain/models/backbones/cspnet.py new file mode 100644 index 0000000..7492e97 --- /dev/null +++ b/mmpretrain/models/backbones/cspnet.py @@ -0,0 +1,679 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import to_ntuple +from .resnet import Bottleneck as ResNetBottleneck +from .resnext import Bottleneck as ResNeXtBottleneck + +eps = 1.0e-5 + + +class DarknetBottleneck(BaseModule): + """The basic bottleneck block used in Darknet. Each DarknetBottleneck + consists of two ConvModules and the input is added to the final output. + Each ConvModule is composed of Conv, BN, and LeakyReLU. The first convLayer + has filter size of 1x1 and the second one has the filter size of 3x3. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the input/output channels of conv2. + Defaults to 4. + add_identity (bool): Whether to add identity to the out. + Defaults to True. + use_depthwise (bool): Whether to use depthwise separable convolution. + Defaults to False. + conv_cfg (dict): Config dict for convolution layer. Defaults to None, + which means using conv2d. + drop_path_rate (float): The ratio of the drop path layer. Default: 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN', eps=1e-5)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='Swish')``. + """ + + def __init__(self, + in_channels, + out_channels, + expansion=2, + add_identity=True, + use_depthwise=False, + conv_cfg=None, + drop_path_rate=0, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + init_cfg=None): + super().__init__(init_cfg) + hidden_channels = int(out_channels / expansion) + conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule + self.conv1 = ConvModule( + in_channels, + hidden_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = conv( + hidden_channels, + out_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.add_identity = \ + add_identity and in_channels == out_channels + + self.drop_path = DropPath(drop_prob=drop_path_rate + ) if drop_path_rate > eps else nn.Identity() + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.conv2(out) + out = self.drop_path(out) + + if self.add_identity: + return out + identity + else: + return out + + +class CSPStage(BaseModule): + """Cross Stage Partial Stage. + + .. code:: text + + Downsample Convolution (optional) + | + | + Expand Convolution + | + | + Split to xa, xb + | \ + | \ + | blocks(xb) + | / + | / transition + | / + Concat xa, blocks(xb) + | + Transition Convolution + + Args: + block_fn (nn.module): The basic block function in the Stage. + in_channels (int): The input channels of the CSP layer. + out_channels (int): The output channels of the CSP layer. + has_downsampler (bool): Whether to add a downsampler in the stage. + Default: False. + down_growth (bool): Whether to expand the channels in the + downsampler layer of the stage. Default: False. + expand_ratio (float): The expand ratio to adjust the number of + channels of the expand conv layer. Default: 0.5 + bottle_ratio (float): Ratio to adjust the number of channels of the + hidden layer. 
Default: 0.5 + block_dpr (float): The ratio of the drop path layer in the + blocks of the stage. Default: 0. + num_blocks (int): Number of blocks. Default: 1 + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', inplace=True) + """ + + def __init__(self, + block_fn, + in_channels, + out_channels, + has_downsampler=True, + down_growth=False, + expand_ratio=0.5, + bottle_ratio=2, + num_blocks=1, + block_dpr=0, + block_args={}, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + init_cfg=None): + super().__init__(init_cfg) + # grow downsample channels to output channels + down_channels = out_channels if down_growth else in_channels + block_dpr = to_ntuple(num_blocks)(block_dpr) + + if has_downsampler: + self.downsample_conv = ConvModule( + in_channels=in_channels, + out_channels=down_channels, + kernel_size=3, + stride=2, + padding=1, + groups=32 if block_fn is ResNeXtBottleneck else 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.downsample_conv = nn.Identity() + + exp_channels = int(down_channels * expand_ratio) + self.expand_conv = ConvModule( + in_channels=down_channels, + out_channels=exp_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg if block_fn is DarknetBottleneck else None) + + assert exp_channels % 2 == 0, \ + 'The channel number before blocks must be divisible by 2.' + block_channels = exp_channels // 2 + blocks = [] + for i in range(num_blocks): + block_cfg = dict( + in_channels=block_channels, + out_channels=block_channels, + expansion=bottle_ratio, + drop_path_rate=block_dpr[i], + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **block_args) + blocks.append(block_fn(**block_cfg)) + self.blocks = Sequential(*blocks) + self.atfer_blocks_conv = ConvModule( + block_channels, + block_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.final_conv = ConvModule( + 2 * block_channels, + out_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + x = self.downsample_conv(x) + x = self.expand_conv(x) + + split = x.shape[1] // 2 + xa, xb = x[:, :split], x[:, split:] + + xb = self.blocks(xb) + xb = self.atfer_blocks_conv(xb).contiguous() + + x_final = torch.cat((xa, xb), dim=1) + return self.final_conv(x_final) + + +class CSPNet(BaseModule): + """The abstract CSP Network class. + + A Pytorch implementation of `CSPNet: A New Backbone that can Enhance + Learning Capability of CNN `_ + + This class is an abstract class because the Cross Stage Partial Network + (CSPNet) is a kind of universal network structure, and you + network block to implement networks like CSPResNet, CSPResNeXt and + CSPDarkNet. + + Args: + arch (dict): The architecture of the CSPNet. + It should have the following keys: + + - block_fn (Callable): A function or class to return a block + module, and it should accept at least ``in_channels``, + ``out_channels``, ``expansion``, ``drop_path_rate``, ``norm_cfg`` + and ``act_cfg``. + - in_channels (Tuple[int]): The number of input channels of each + stage. + - out_channels (Tuple[int]): The number of output channels of each + stage. + - num_blocks (Tuple[int]): The number of blocks in each stage. + - expansion_ratio (float | Tuple[float]): The expansion ratio in + the expand convolution of each stage. Defaults to 0.5. 
+ - bottle_ratio (float | Tuple[float]): The expansion ratio of + blocks in each stage. Defaults to 2. + - has_downsampler (bool | Tuple[bool]): Whether to add a + downsample convolution in each stage. Defaults to True + - down_growth (bool | Tuple[bool]): Whether to expand the channels + in the downsampler layer of each stage. Defaults to False. + - block_args (dict | Tuple[dict], optional): The extra arguments to + the blocks in each stage. Defaults to None. + + stem_fn (Callable): A function or class to return a stem module. + And it should accept ``in_channels``. + in_channels (int): Number of input image channels. Defaults to 3. + out_indices (int | Sequence[int]): Output from which stages. + Defaults to -1, which means the last stage. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + conv_cfg (dict, optional): The config dict for conv layers in blocks. + Defaults to None, which means use Conv2d. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN', eps=1e-5)``. + act_cfg (dict): The config dict for activation functions. + Defaults to ``dict(type='LeakyReLU', inplace=True)``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + init_cfg (dict, optional): The initialization settings. + Defaults to ``dict(type='Kaiming', layer='Conv2d'))``. + + Example: + >>> from functools import partial + >>> import torch + >>> import torch.nn as nn + >>> from mmpretrain.models import CSPNet + >>> from mmpretrain.models.backbones.resnet import Bottleneck + >>> + >>> # A simple example to build CSPNet. + >>> arch = dict( + ... block_fn=Bottleneck, + ... in_channels=[32, 64], + ... out_channels=[64, 128], + ... num_blocks=[3, 4] + ... ) + >>> stem_fn = partial(nn.Conv2d, out_channels=32, kernel_size=3) + >>> model = CSPNet(arch=arch, stem_fn=stem_fn, out_indices=(0, 1)) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outs = model(inputs) + >>> for out in outs: + ... print(out.shape) + ... + (1, 64, 111, 111) + (1, 128, 56, 56) + """ + + def __init__(self, + arch, + stem_fn, + in_channels=3, + out_indices=-1, + frozen_stages=-1, + drop_path_rate=0., + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + norm_eval=False, + init_cfg=dict(type='Kaiming', layer='Conv2d')): + super().__init__(init_cfg=init_cfg) + self.arch = self.expand_arch(arch) + self.num_stages = len(self.arch['in_channels']) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + if frozen_stages not in range(-1, self.num_stages): + raise ValueError('frozen_stages must be in range(-1, ' + f'{self.num_stages}). 
But received ' + f'{frozen_stages}') + self.frozen_stages = frozen_stages + + self.stem = stem_fn(in_channels) + + stages = [] + depths = self.arch['num_blocks'] + dpr = torch.linspace(0, drop_path_rate, sum(depths)).split(depths) + + for i in range(self.num_stages): + stage_cfg = {k: v[i] for k, v in self.arch.items()} + csp_stage = CSPStage( + **stage_cfg, + block_dpr=dpr[i].tolist(), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + init_cfg=init_cfg) + stages.append(csp_stage) + self.stages = Sequential(*stages) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = len(self.stages) + index + assert 0 <= out_indices[i] <= len(self.stages), \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + @staticmethod + def expand_arch(arch): + num_stages = len(arch['in_channels']) + + def to_tuple(x, name=''): + if isinstance(x, (list, tuple)): + assert len(x) == num_stages, \ + f'The length of {name} ({len(x)}) does not ' \ + f'equals to the number of stages ({num_stages})' + return tuple(x) + else: + return (x, ) * num_stages + + full_arch = {k: to_tuple(v, k) for k, v in arch.items()} + if 'block_args' not in full_arch: + full_arch['block_args'] = to_tuple({}) + return full_arch + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(CSPNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def forward(self, x): + outs = [] + + x = self.stem(x) + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + +@MODELS.register_module() +class CSPDarkNet(CSPNet): + """CSP-Darknet backbone used in YOLOv4. + + Args: + depth (int): Depth of CSP-Darknet. Default: 53. + in_channels (int): Number of input image channels. Default: 3. + out_indices (Sequence[int]): Output from which stages. + Default: (3, ). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from mmpretrain.models import CSPDarkNet + >>> import torch + >>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 416, 416) + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... 
+ (1, 64, 208, 208) + (1, 128, 104, 104) + (1, 256, 52, 52) + (1, 512, 26, 26) + (1, 1024, 13, 13) + """ + arch_settings = { + 53: + dict( + block_fn=DarknetBottleneck, + in_channels=(32, 64, 128, 256, 512), + out_channels=(64, 128, 256, 512, 1024), + num_blocks=(1, 2, 8, 8, 4), + expand_ratio=(2, 1, 1, 1, 1), + bottle_ratio=(2, 1, 1, 1, 1), + has_downsampler=True, + down_growth=True, + ), + } + + def __init__(self, + depth, + in_channels=3, + out_indices=(4, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + norm_eval=False, + init_cfg=dict( + type='Kaiming', + layer='Conv2d', + a=math.sqrt(5), + distribution='uniform', + mode='fan_in', + nonlinearity='leaky_relu')): + + assert depth in self.arch_settings, 'depth must be one of ' \ + f'{list(self.arch_settings.keys())}, but get {depth}.' + + super().__init__( + arch=self.arch_settings[depth], + stem_fn=self._make_stem_layer, + in_channels=in_channels, + out_indices=out_indices, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + norm_eval=norm_eval, + init_cfg=init_cfg) + + def _make_stem_layer(self, in_channels): + """using a stride=1 conv as the stem in CSPDarknet.""" + # `stem_channels` equals to the `in_channels` in the first stage. + stem_channels = self.arch['in_channels'][0] + stem = ConvModule( + in_channels=in_channels, + out_channels=stem_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + return stem + + +@MODELS.register_module() +class CSPResNet(CSPNet): + """CSP-ResNet backbone. + + Args: + depth (int): Depth of CSP-ResNet. Default: 50. + out_indices (Sequence[int]): Output from which stages. + Default: (4, ). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Example: + >>> from mmpretrain.models import CSPResNet + >>> import torch + >>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 416, 416) + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... + (1, 128, 104, 104) + (1, 256, 52, 52) + (1, 512, 26, 26) + (1, 1024, 13, 13) + """ + arch_settings = { + 50: + dict( + block_fn=ResNetBottleneck, + in_channels=(64, 128, 256, 512), + out_channels=(128, 256, 512, 1024), + num_blocks=(3, 3, 5, 2), + expand_ratio=4, + bottle_ratio=2, + has_downsampler=(False, True, True, True), + down_growth=False), + } + + def __init__(self, + depth, + in_channels=3, + out_indices=(3, ), + frozen_stages=-1, + deep_stem=False, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-5), + act_cfg=dict(type='LeakyReLU', inplace=True), + norm_eval=False, + init_cfg=dict(type='Kaiming', layer='Conv2d')): + assert depth in self.arch_settings, 'depth must be one of ' \ + f'{list(self.arch_settings.keys())}, but get {depth}.' 
+ self.deep_stem = deep_stem + + super().__init__( + arch=self.arch_settings[depth], + stem_fn=self._make_stem_layer, + in_channels=in_channels, + out_indices=out_indices, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + norm_eval=norm_eval, + init_cfg=init_cfg) + + def _make_stem_layer(self, in_channels): + # `stem_channels` equals to the `in_channels` in the first stage. + stem_channels = self.arch['in_channels'][0] + if self.deep_stem: + stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + else: + stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + return stem + + +@MODELS.register_module() +class CSPResNeXt(CSPResNet): + """CSP-ResNeXt backbone. + + Args: + depth (int): Depth of CSP-ResNeXt. Default: 50. + out_indices (Sequence[int]): Output from which stages. + Default: (4, ). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LeakyReLU', negative_slope=0.1). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Example: + >>> from mmpretrain.models import CSPResNeXt + >>> import torch + >>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... + (1, 256, 56, 56) + (1, 512, 28, 28) + (1, 1024, 14, 14) + (1, 2048, 7, 7) + """ + arch_settings = { + 50: + dict( + block_fn=ResNeXtBottleneck, + in_channels=(64, 256, 512, 1024), + out_channels=(256, 512, 1024, 2048), + num_blocks=(3, 3, 5, 2), + expand_ratio=(4, 2, 2, 2), + bottle_ratio=4, + has_downsampler=(False, True, True, True), + down_growth=False, + # the base_channels is changed from 64 to 32 in CSPNet + block_args=dict(base_channels=32), + ), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/mmpretrain/models/backbones/davit.py b/mmpretrain/models/backbones/davit.py new file mode 100644 index 0000000..cf25e2e --- /dev/null +++ b/mmpretrain/models/backbones/davit.py @@ -0,0 +1,834 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
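The CSP backbones above are all registered in the MODELS registry, so besides direct instantiation (as in the class docstring examples) they can be built from a config dict. The following is a minimal sketch, assuming mmpretrain is importable and the classes are registered as shown above; the expected output shapes follow the CSPDarkNet docstring example.

import torch
from mmpretrain.registry import MODELS

# Build CSP-Darknet-53 through the registry and take the last three stages.
cfg = dict(type='CSPDarkNet', depth=53, out_indices=(2, 3, 4))
model = MODELS.build(cfg)
model.eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 416, 416))
for feat in feats:
    print(tuple(feat.shape))
# Expected, per the docstring example above:
# (1, 256, 52, 52), (1, 512, 26, 26), (1, 1024, 13, 13)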
+from copy import deepcopy +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.cnn.bricks import Conv2d +from mmcv.cnn.bricks.transformer import FFN, AdaptivePadding, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.utils import to_2tuple +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import ShiftWindowMSA + + +class DaViTWindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module for DaViT. + + The differences between DaViTWindowMSA & WindowMSA: + 1. Without relative position bias. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop (float, optional): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float, optional): Dropout ratio of output. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ConvPosEnc(BaseModule): + """DaViT conv pos encode block. + + Args: + embed_dims (int): Number of input channels. + kernel_size (int): The kernel size of the first convolution. + Defaults to 3. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
+ """ + + def __init__(self, embed_dims, kernel_size=3, init_cfg=None): + super(ConvPosEnc, self).__init__(init_cfg) + self.proj = Conv2d( + embed_dims, + embed_dims, + kernel_size, + stride=1, + padding=kernel_size // 2, + groups=embed_dims) + + def forward(self, x, size: Tuple[int, int]): + B, N, C = x.shape + H, W = size + assert N == H * W + + feat = x.transpose(1, 2).view(B, C, H, W) + feat = self.proj(feat) + feat = feat.flatten(2).transpose(1, 2) + x = x + feat + return x + + +class DaViTDownSample(BaseModule): + """DaViT down sampole block. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + conv_type (str): The type of convolution + to generate patch embedding. Default: "Conv2d". + kernel_size (int): The kernel size of the first convolution. + Defaults to 2. + stride (int): The stride of the second convluation module. + Defaults to 2. + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Defaults to "corner". + dilation (int): Dilation of the convolution layers. Defaults to 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_type='Conv2d', + kernel_size=2, + stride=2, + padding='same', + dilation=1, + bias=True, + norm_cfg=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.out_channels = out_channels + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adaptive_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adaptive_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, in_channels)[1] + else: + self.norm = None + + def forward(self, x, input_size): + if self.adaptive_padding: + x = self.adaptive_padding(x) + H, W = input_size + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + + x = self.norm(x) + x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() + + x = self.projection(x) + output_size = (x.size(2), x.size(3)) + x = x.flatten(2).transpose(1, 2) + return x, output_size + + +class ChannelAttention(BaseModule): + """DaViT channel attention. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
+ """ + + def __init__(self, embed_dims, num_heads=8, qkv_bias=False, init_cfg=None): + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + self.head_dims = embed_dims // num_heads + self.scale = self.head_dims**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dims, embed_dims) + + def forward(self, x): + B, N, _ = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + k = k * self.scale + attention = k.transpose(-1, -2) @ v + attention = attention.softmax(dim=-1) + + x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + x = self.proj(x) + return x + + +class ChannelBlock(BaseModule): + """DaViT channel attention block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + ffn_ratio=4., + qkv_bias=False, + drop_path=0., + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg) + self.with_cp = with_cp + + self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ChannelAttention( + embed_dims, num_heads=num_heads, qkv_bias=qkv_bias) + self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.cpe1(x, hw_shape) + identity = x + x = self.norm1(x) + x = self.attn(x) + x = x + identity + + x = self.cpe2(x, hw_shape) + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SpatialBlock(BaseModule): + """DaViT spatial attention block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. 
If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(SpatialBlock, self).__init__(init_cfg) + self.with_cp = with_cp + + self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'qkv_bias': qkv_bias, + 'pad_small_map': pad_small_map, + 'window_msa': DaViTWindowMSA, + **attn_cfgs + } + self.attn = ShiftWindowMSA(**_attn_cfgs) + self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.cpe1(x, hw_shape) + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + x = x + identity + + x = self.cpe2(x, hw_shape) + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class DaViTBlock(BaseModule): + """DaViT block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(DaViTBlock, self).__init__(init_cfg) + self.spatial_block = SpatialBlock( + embed_dims, + num_heads, + window_size=window_size, + ffn_ratio=ffn_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path, + pad_small_map=pad_small_map, + attn_cfgs=attn_cfgs, + ffn_cfgs=ffn_cfgs, + norm_cfg=norm_cfg, + with_cp=with_cp) + self.channel_block = ChannelBlock( + embed_dims, + num_heads, + ffn_ratio=ffn_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path, + ffn_cfgs=ffn_cfgs, + norm_cfg=norm_cfg, + with_cp=False) + + def forward(self, x, hw_shape): + x = self.spatial_block(x, hw_shape) + x = self.channel_block(x, hw_shape) + + return x + + +class DaViTBlockSequence(BaseModule): + """Module with successive DaViT blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive DaViT blocks. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + self.embed_dims = embed_dims + self.blocks = ModuleList() + for i in range(depth): + _block_cfg = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'window_size': window_size, + 'ffn_ratio': ffn_ratio, + 'qkv_bias': qkv_bias, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **block_cfgs[i] + } + block = DaViTBlock(**_block_cfg) + self.blocks.append(block) + + if downsample: + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': 2 * embed_dims, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = DaViTDownSample(**_downsample_cfg) + else: + self.downsample = None + + def forward(self, x, in_shape, do_downsample=True): + for block in self.blocks: + x = block(x, in_shape) + + if self.downsample is not None and do_downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + return x, out_shape + + @property + def out_channels(self): + if self.downsample: + return self.downsample.out_channels + else: + return self.embed_dims + + +@MODELS.register_module() +class DaViT(BaseBackbone): + """DaViT. + + A PyTorch implement of : `DaViT: Dual Attention Vision Transformers + `_ + + Inspiration from + https://github.com/dingmyu/davit + + Args: + arch (str | dict): DaViT architecture. If use string, choose from + 'tiny', 'small', 'base' and 'large', 'huge', 'giant'. If use dict, + it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + + Defaults to 't'. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int): The height and width of the window. Defaults to 7. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_after_downsample (bool): Whether to output the feature map of a + stage after the following downsample layer. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. 
+ out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], { + 'embed_dims': 96, + 'depths': [1, 1, 3, 1], + 'num_heads': [3, 6, 12, 24] + }), + **dict.fromkeys(['s', 'small'], { + 'embed_dims': 96, + 'depths': [1, 1, 9, 1], + 'num_heads': [3, 6, 12, 24] + }), + **dict.fromkeys(['b', 'base'], { + 'embed_dims': 128, + 'depths': [1, 1, 9, 1], + 'num_heads': [4, 8, 16, 32] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 192, + 'depths': [1, 1, 9, 1], + 'num_heads': [6, 12, 24, 48] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 256, + 'depths': [1, 1, 9, 1], + 'num_heads': [8, 16, 32, 64] + }), + **dict.fromkeys( + ['g', 'giant'], { + 'embed_dims': 384, + 'depths': [1, 1, 12, 3], + 'num_heads': [12, 24, 48, 96] + }), + } + + def __init__(self, + arch='t', + patch_size=4, + in_channels=3, + window_size=7, + ffn_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + out_after_downsample=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + frozen_stages=-1, + norm_eval=False, + out_indices=(3, ), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.out_after_downsample = out_after_downsample + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + # stochastic depth decay rule + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + _patch_cfg = dict( + in_channels=in_channels, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=7, + stride=patch_size, + padding='same', + norm_cfg=dict(type='LN'), + ) + self.patch_embed = PatchEmbed(**_patch_cfg) + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i < self.num_layers - 1 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': window_size, + 'ffn_ratio': ffn_ratio, + 'qkv_bias': qkv_bias, + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **stage_cfg + } + + stage = DaViTBlockSequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + self.num_features = embed_dims[:-1] + + # add a norm layer for each output + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, + self.num_features[i])[1] + else: + 
norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage( + x, hw_shape, do_downsample=self.out_after_downsample) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + if stage.downsample is not None and not self.out_after_downsample: + x, hw_shape = stage.downsample(x, hw_shape) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/deit.py b/mmpretrain/models/backbones/deit.py new file mode 100644 index 0000000..9ae3408 --- /dev/null +++ b/mmpretrain/models/backbones/deit.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from .vision_transformer import VisionTransformer + + +@MODELS.register_module() +class DistilledVisionTransformer(VisionTransformer): + """Distilled Vision Transformer. + + A PyTorch implement of : `Training data-efficient image transformers & + distillation through attention `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'deit-base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: A tuple with the class token and the + distillation token. The shapes of both tensor are (B, C). 
+ - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + num_extra_tokens = 2 # class token and distillation token + + def __init__(self, arch='deit-base', *args, **kwargs): + super(DistilledVisionTransformer, self).__init__( + arch=arch, + with_cls_token=True, + *args, + **kwargs, + ) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + x = x + self.resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'cls_token': + return x[:, 0], x[:, 1] + + return super()._format_output(x, hw) + + def init_weights(self): + super(DistilledVisionTransformer, self).init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + trunc_normal_(self.dist_token, std=0.02) diff --git a/mmpretrain/models/backbones/deit3.py b/mmpretrain/models/backbones/deit3.py new file mode 100644 index 0000000..acedabe --- /dev/null +++ b/mmpretrain/models/backbones/deit3.py @@ -0,0 +1,454 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import numpy as np +import torch +from mmcv.cnn import Linear, build_activation_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.utils import deprecated_api_warning +from torch import nn + +from mmpretrain.registry import MODELS +from ..utils import (LayerScale, MultiheadAttention, build_norm_layer, + resize_pos_embed, to_2tuple) +from .vision_transformer import VisionTransformer + + +class DeiT3FFN(BaseModule): + """FFN for DeiT3. + + The differences between DeiT3FFN & FFN: + 1. Use LayerScale. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. 
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + use_layer_scale (bool): Whether to use layer_scale in + DeiT3FFN. Defaults to True. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning( + { + 'dropout': 'ffn_drop', + 'add_residual': 'add_identity' + }, + cls_name='FFN') + def __init__(self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0., + dropout_layer=None, + add_identity=True, + use_layer_scale=True, + init_cfg=None, + **kwargs): + super().__init__(init_cfg) + assert num_fcs >= 2, 'num_fcs should be no less ' \ + f'than 2. got {num_fcs}.' + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + Sequential( + Linear(in_channels, feedforward_channels), self.activate, + nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + if use_layer_scale: + self.gamma2 = LayerScale(embed_dims) + else: + self.gamma2 = nn.Identity() + + @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN') + def forward(self, x, identity=None): + """Forward function for `FFN`. + + The function would add x to the output tensor if residue is None. + """ + out = self.layers(x) + out = self.gamma2(out) + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class DeiT3TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in DeiT3. + + The differences between DeiT3TransformerEncoderLayer & + TransformerEncoderLayer: + 1. Use LayerScale. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + use_layer_scale (bool): Whether to use layer_scale in + DeiT3TransformerEncoderLayer. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + use_layer_scale=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(DeiT3TransformerEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + use_layer_scale=use_layer_scale) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + self.ffn = DeiT3FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + use_layer_scale=use_layer_scale) + + def init_weights(self): + super(DeiT3TransformerEncoderLayer, self).init_weights() + for m in self.ffn.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln1(x), identity=x) + return x + + +@MODELS.register_module() +class DeiT3(VisionTransformer): + """DeiT3 backbone. + + A PyTorch implement of : `DeiT III: Revenge of the ViT + `_ + + The differences between DeiT3 & VisionTransformer: + + 1. Use LayerScale. + 2. Concat cls token after adding pos_embed. + + Args: + arch (str | dict): DeiT3 architecture. If use string, + choose from 'small', 'base', 'medium', 'large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + use_layer_scale (bool): Whether to use layer_scale in DeiT3. 
+ Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 1536, + }), + **dict.fromkeys( + ['m', 'medium'], { + 'embed_dims': 512, + 'num_layers': 12, + 'num_heads': 8, + 'feedforward_channels': 2048, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120 + }), + } + num_extra_tokens = 1 # class token + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='cls_token', + with_cls_token=True, + use_layer_scale=True, + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' 
+ for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + use_layer_scale=use_layer_scale) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(DeiT3TransformerEncoderLayer(**_layer_cfg)) + + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=0) + x = self.drop_after_pos(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1]))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed( + state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + num_extra_tokens=0, # The cls token adding is after pos_embed + ) diff --git a/mmpretrain/models/backbones/densenet.py b/mmpretrain/models/backbones/densenet.py new file mode 100644 index 0000000..c9f0530 --- /dev/null +++ b/mmpretrain/models/backbones/densenet.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
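Because DeiT3 only stores position embeddings for the patch tokens and concatenates the class token after the positional embedding is added, a quick way to sanity-check the backbone is to build it and inspect the default class-token output. A minimal sketch, assuming mmpretrain is importable and DeiT3 is registered as above; the shape follows the 'small' entry of arch_zoo.

import torch
from mmpretrain.registry import MODELS

model = MODELS.build(dict(type='DeiT3', arch='small', img_size=224, patch_size=16))
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
# With the default out_indices=-1 and out_type='cls_token', the result is a
# one-element tuple holding the class token, shape (1, 384) for arch='small'.
print(out[0].shape)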
+import math +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import build_activation_layer, build_norm_layer +from torch.jit.annotations import List + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class DenseLayer(BaseBackbone): + """DenseBlock layers.""" + + def __init__(self, + in_channels, + growth_rate, + bn_size, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_rate=0., + memory_efficient=False): + super(DenseLayer, self).__init__() + + self.norm1 = build_norm_layer(norm_cfg, in_channels)[1] + self.conv1 = nn.Conv2d( + in_channels, + bn_size * growth_rate, + kernel_size=1, + stride=1, + bias=False) + self.act = build_activation_layer(act_cfg) + self.norm2 = build_norm_layer(norm_cfg, bn_size * growth_rate)[1] + self.conv2 = nn.Conv2d( + bn_size * growth_rate, + growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bottleneck_fn(self, xs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(xs, 1) + bottleneck_output = self.conv1( + self.act(self.norm1(concated_features))) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, x): + # type: (List[torch.Tensor]) -> bool + for tensor in x: + if tensor.requires_grad: + return True + return False + + # This decorator indicates to the compiler that a function or method + # should be ignored and replaced with the raising of an exception. + # Here this function is incompatible with torchscript. + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, x): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*xs): + return self.bottleneck_fn(xs) + + # Here use torch.utils.checkpoint to rerun a forward-pass during + # backward in bottleneck to save memories. 
+ return cp.checkpoint(closure, *x) + + def forward(self, x): # noqa: F811 + # type: (List[torch.Tensor]) -> torch.Tensor + # assert input features is a list of Tensor + assert isinstance(x, list) + + if self.memory_efficient and self.any_requires_grad(x): + if torch.jit.is_scripting(): + raise Exception('Memory Efficient not supported in JIT') + bottleneck_output = self.call_checkpoint_bottleneck(x) + else: + bottleneck_output = self.bottleneck_fn(x) + + new_features = self.conv2(self.act(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = F.dropout( + new_features, p=self.drop_rate, training=self.training) + return new_features + + +class DenseBlock(nn.Module): + """DenseNet Blocks.""" + + def __init__(self, + num_layers, + in_channels, + bn_size, + growth_rate, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_rate=0., + memory_efficient=False): + super(DenseBlock, self).__init__() + self.block = nn.ModuleList([ + DenseLayer( + in_channels + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + memory_efficient=memory_efficient) for i in range(num_layers) + ]) + + def forward(self, init_features): + features = [init_features] + for layer in self.block: + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + + +class DenseTransition(nn.Sequential): + """DenseNet Transition Layers.""" + + def __init__(self, + in_channels, + out_channels, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU')): + super(DenseTransition, self).__init__() + self.add_module('norm', build_norm_layer(norm_cfg, in_channels)[1]) + self.add_module('act', build_activation_layer(act_cfg)) + self.add_module( + 'conv', + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, + bias=False)) + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +@MODELS.register_module() +class DenseNet(BaseBackbone): + """DenseNet. + + A PyTorch implementation of : `Densely Connected Convolutional Networks + `_ + + Modified from the `official repo + `_ + and `pytorch + `_. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``DenseNet.arch_settings``. And if dict, it + should include the following two keys: + + - growth_rate (int): Each layer of DenseBlock produce `k` feature + maps. Here refers `k` as the growth rate of the network. + - depths (list[int]): Number of repeated layers in each DenseBlock. + - init_channels (int): The output channels of stem layers. + + Defaults to '121'. + in_channels (int): Number of input image channels. Defaults to 3. + bn_size (int): Refers to channel expansion parameter of 1x1 + convolution layer. Defaults to 4. + drop_rate (float): Drop rate of Dropout Layer. Defaults to 0. + compression_factor (float): The reduction rate of transition layers. + Defaults to 0.5. + memory_efficient (bool): If True, uses checkpointing. Much more memory + efficient, but slower. Defaults to False. + See `"paper" `_. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation after each convolution. + Defaults to ``dict(type='ReLU')``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict. 
+ """ + arch_settings = { + '121': { + 'growth_rate': 32, + 'depths': [6, 12, 24, 16], + 'init_channels': 64, + }, + '169': { + 'growth_rate': 32, + 'depths': [6, 12, 32, 32], + 'init_channels': 64, + }, + '201': { + 'growth_rate': 32, + 'depths': [6, 12, 48, 32], + 'init_channels': 64, + }, + '161': { + 'growth_rate': 48, + 'depths': [6, 12, 36, 24], + 'init_channels': 96, + }, + } + + def __init__(self, + arch='121', + in_channels=3, + bn_size=4, + drop_rate=0, + compression_factor=0.5, + memory_efficient=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + out_indices=-1, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = {'growth_rate', 'depths', 'init_channels'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + + self.growth_rate = arch['growth_rate'] + self.depths = arch['depths'] + self.init_channels = arch['init_channels'] + self.act = build_activation_layer(act_cfg) + + self.num_stages = len(self.depths) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # Set stem layers + self.stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.init_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False), + build_norm_layer(norm_cfg, self.init_channels)[1], self.act, + nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + # Repetitions of DenseNet Blocks + self.stages = nn.ModuleList() + self.transitions = nn.ModuleList() + + channels = self.init_channels + for i in range(self.num_stages): + depth = self.depths[i] + + stage = DenseBlock( + num_layers=depth, + in_channels=channels, + bn_size=bn_size, + growth_rate=self.growth_rate, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + memory_efficient=memory_efficient) + self.stages.append(stage) + channels += depth * self.growth_rate + + if i != self.num_stages - 1: + transition = DenseTransition( + in_channels=channels, + out_channels=math.floor(channels * compression_factor), + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + channels = math.floor(channels * compression_factor) + else: + # Final layers after dense block is just bn with act. + # Unlike the paper, the original repo also put this in + # transition layer, whereas torchvision take this out. + # We reckon this as transition layer here. 
+ transition = nn.Sequential( + build_norm_layer(norm_cfg, channels)[1], + self.act, + ) + self.transitions.append(transition) + + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i in range(self.num_stages): + x = self.stages[i](x) + x = self.transitions[i](x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.transitions[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(DenseNet, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/edgenext.py b/mmpretrain/models/backbones/edgenext.py new file mode 100644 index 0000000..ad4e768 --- /dev/null +++ b/mmpretrain/models/backbones/edgenext.py @@ -0,0 +1,398 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from itertools import chain +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import (ChannelMultiheadAttention, PositionEncodingFourier, + build_norm_layer) +from .base_backbone import BaseBackbone +from .convnext import ConvNeXtBlock + + +class SDTAEncoder(BaseModule): + """A PyTorch implementation of split depth-wise transpose attention (SDTA) + encoder. + + Inspiration from + https://github.com/mmaaz60/EdgeNeXt + Args: + in_channel (int): Number of input channels. + drop_path_rate (float): Stochastic depth dropout rate. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + mlp_ratio (int): Number of channels ratio in the MLP. + Defaults to 4. + use_pos_emb (bool): Whether to use position encoding. + Defaults to True. + num_heads (int): Number of heads in the multihead attention. + Defaults to 8. + qkv_bias (bool): Whether to use bias in the multihead attention. + Defaults to True. + attn_drop (float): Dropout rate of the attention. + Defaults to 0. + proj_drop (float): Dropout rate of the projection. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + norm_cfg (dict): Dictionary to construct normalization layer. + Defaults to ``dict(type='LN')``. + act_cfg (dict): Dictionary to construct activation layer. + Defaults to ``dict(type='GELU')``. + scales (int): Number of scales. Default to 1. 
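+
+    Example (illustrative only; with positional encoding disabled the
+    encoder keeps the input's (B, C, H, W) shape):
+        >>> import torch
+        >>> encoder = SDTAEncoder(
+        ...     in_channel=48, num_heads=4, scales=2, use_pos_emb=False)
+        >>> out = encoder(torch.rand(1, 48, 28, 28))
+        >>> print(tuple(out.shape))
+        (1, 48, 28, 28)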
+ """ + + def __init__(self, + in_channel, + drop_path_rate=0., + layer_scale_init_value=1e-6, + mlp_ratio=4, + use_pos_emb=True, + num_heads=8, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + scales=1, + init_cfg=None): + super(SDTAEncoder, self).__init__(init_cfg=init_cfg) + conv_channels = max( + int(math.ceil(in_channel / scales)), + int(math.floor(in_channel // scales))) + self.conv_channels = conv_channels + self.num_convs = scales if scales == 1 else scales - 1 + + self.conv_modules = ModuleList() + for i in range(self.num_convs): + self.conv_modules.append( + nn.Conv2d( + conv_channels, + conv_channels, + kernel_size=3, + padding=1, + groups=conv_channels)) + + self.pos_embed = PositionEncodingFourier( + embed_dims=in_channel) if use_pos_emb else None + + self.norm_csa = build_norm_layer(norm_cfg, in_channel) + self.gamma_csa = nn.Parameter( + layer_scale_init_value * torch.ones(in_channel), + requires_grad=True) if layer_scale_init_value > 0 else None + self.csa = ChannelMultiheadAttention( + embed_dims=in_channel, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop) + + self.norm = build_norm_layer(norm_cfg, in_channel) + self.pointwise_conv1 = nn.Linear(in_channel, mlp_ratio * in_channel) + self.act = MODELS.build(act_cfg) + self.pointwise_conv2 = nn.Linear(mlp_ratio * in_channel, in_channel) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones(in_channel), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + spx = torch.split(x, self.conv_channels, dim=1) + for i in range(self.num_convs): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.conv_modules[i](sp) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + x = torch.cat((out, spx[self.num_convs]), 1) + + # Channel Self-attention + B, C, H, W = x.shape + x = x.reshape(B, C, H * W).permute(0, 2, 1) + if self.pos_embed: + pos_encoding = self.pos_embed((B, H, W)) + pos_encoding = pos_encoding.reshape(B, -1, + x.shape[1]).permute(0, 2, 1) + x += pos_encoding + + x = x + self.drop_path(self.gamma_csa * self.csa(self.norm_csa(x))) + x = x.reshape(B, H, W, C) + + # Inverted Bottleneck + x = self.norm(x) + x = self.pointwise_conv1(x) + x = self.act(x) + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) + + x = shortcut + self.drop_path(x) + + return x + + +@MODELS.register_module() +class EdgeNeXt(BaseBackbone): + """EdgeNeXt. + + A PyTorch implementation of: `EdgeNeXt: Efficiently Amalgamated + CNN-Transformer Architecture for Mobile Vision Applications + `_ + + Inspiration from + https://github.com/mmaaz60/EdgeNeXt + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architectures in ``EdgeNeXt.arch_settings``. + And if dict, it should include the following keys: + + - channels (list[int]): The number of channels at each stage. + - depths (list[int]): The number of blocks at each stage. + - num_heads (list[int]): The number of heads at each stage. + + Defaults to 'xxsmall'. + in_channels (int): The number of input channels. + Defaults to 3. + global_blocks (list[int]): The number of global blocks. + Defaults to [0, 1, 1, 1]. + global_block_type (list[str]): The type of global blocks. + Defaults to ['None', 'SDTA', 'SDTA', 'SDTA']. 
+ drop_path_rate (float): Stochastic depth dropout rate. + Defaults to 0. + layer_scale_init_value (float): Initial value of layer scale. + Defaults to 1e-6. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to False. + mlp_ratio (int): The number of channel ratio in MLP layers. + Defaults to 4. + conv_kernel_size (list[int]): The kernel size of convolutional layers + at each stage. Defaults to [3, 5, 7, 9]. + use_pos_embd_csa (list[bool]): Whether to use positional embedding in + Channel Self-Attention. Defaults to [False, True, False, False]. + use_pos_emebd_global (bool): Whether to use positional embedding for + whole network. Defaults to False. + d2_scales (list[int]): The number of channel groups used for SDTA at + each stage. Defaults to [2, 2, 3, 4]. + norm_cfg (dict): The config of normalization layer. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. Defaults to True. + act_cfg (dict): The config of activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict, optional): Config for initialization. + Defaults to None. + """ + arch_settings = { + 'xxsmall': { # parameters: 1.3M + 'channels': [24, 48, 88, 168], + 'depths': [2, 2, 6, 2], + 'num_heads': [4, 4, 4, 4] + }, + 'xsmall': { # parameters: 2.3M + 'channels': [32, 64, 100, 192], + 'depths': [3, 3, 9, 3], + 'num_heads': [4, 4, 4, 4] + }, + 'small': { # parameters: 5.6M + 'channels': [48, 96, 160, 304], + 'depths': [3, 3, 9, 3], + 'num_heads': [8, 8, 8, 8] + }, + 'base': { # parameters: 18.51M + 'channels': [80, 160, 288, 584], + 'depths': [3, 3, 9, 3], + 'num_heads': [8, 8, 8, 8] + }, + } + + def __init__(self, + arch='xxsmall', + in_channels=3, + global_blocks=[0, 1, 1, 1], + global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'], + drop_path_rate=0., + layer_scale_init_value=1e-6, + linear_pw_conv=True, + mlp_ratio=4, + conv_kernel_sizes=[3, 5, 7, 9], + use_pos_embd_csa=[False, True, False, False], + use_pos_embd_global=False, + d2_scales=[2, 2, 3, 4], + norm_cfg=dict(type='LN2d', eps=1e-6), + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + act_cfg=dict(type='GELU'), + init_cfg=None): + super(EdgeNeXt, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in self.arch_settings, \ + f'Arch {arch} is not in default archs ' \ + f'{set(self.arch_settings)}' + self.arch_settings = self.arch_settings[arch] + elif isinstance(arch, dict): + essential_keys = {'channels', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.channels = self.arch_settings['channels'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = len(self.depths) + self.use_pos_embd_global = use_pos_embd_global + + for g in global_block_type: + assert g in ['None', + 'SDTA'], f'Global block type {g} is not supported' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} 
instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + + if self.use_pos_embd_global: + self.pos_embed = PositionEncodingFourier( + embed_dims=self.channels[0]) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + + self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d(in_channels, self.channels[0], kernel_size=4, stride=4), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + self.stages = ModuleList() + block_idx = 0 + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2, + )) + self.downsample_layers.append(downsample_layer) + + stage_blocks = [] + for j in range(depth): + if j > depth - global_blocks[i] - 1: + stage_blocks.append( + SDTAEncoder( + in_channel=channels, + drop_path_rate=dpr[block_idx + j], + mlp_ratio=mlp_ratio, + scales=d2_scales[i], + use_pos_emb=use_pos_embd_csa[i], + num_heads=self.num_heads[i], + )) + else: + dw_conv_cfg = dict( + kernel_size=conv_kernel_sizes[i], + padding=conv_kernel_sizes[i] // 2, + ) + stage_blocks.append( + ConvNeXtBlock( + in_channels=channels, + dw_conv_cfg=dw_conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + drop_path_rate=dpr[block_idx + j], + layer_scale_init_value=layer_scale_init_value, + )) + block_idx += depth + + stage_blocks = Sequential(*stage_blocks) + self.stages.append(stage_blocks) + + if i in self.out_indices: + out_norm_cfg = dict(type='LN') if self.gap_before_final_norm \ + else norm_cfg + norm_layer = build_norm_layer(out_norm_cfg, channels) + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self) -> None: + # TODO: need to be implemented in the future + return super().init_weights() + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if self.pos_embed and i == 0: + B, _, H, W = x.shape + x += self.pos_embed((B, H, W)) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap.flatten(1))) + else: + # The output of LayerNorm2d may be discontiguous, which + # may cause some problem in the downstream tasks + outs.append(norm_layer(x).contiguous()) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + downsample_layer = self.downsample_layers[i] + stage = self.stages[i] + downsample_layer.eval() + stage.eval() + for param in chain(downsample_layer.parameters(), + stage.parameters()): + param.requires_grad = False + + def train(self, mode=True): + super(EdgeNeXt, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/efficientformer.py b/mmpretrain/models/backbones/efficientformer.py new file mode 100644 index 0000000..c2525c8 --- /dev/null +++ b/mmpretrain/models/backbones/efficientformer.py @@ -0,0 +1,606 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import itertools +from typing import Optional, Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer, + build_norm_layer) +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import LayerScale +from .base_backbone import BaseBackbone +from .poolformer import Pooling + + +class AttentionWithBias(BaseModule): + """Multi-head Attention Module with attention_bias. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. Defaults to 8. + key_dim (int): The dimension of q, k. Defaults to 32. + attn_ratio (float): The dimension of v equals to + ``key_dim * attn_ratio``. Defaults to 4. + resolution (int): The height and width of attention_bias. + Defaults to 7. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads=8, + key_dim=32, + attn_ratio=4., + resolution=7, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.attn_ratio = attn_ratio + self.key_dim = key_dim + self.nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + h = self.dh + self.nh_kd * 2 + self.qkv = nn.Linear(embed_dims, h) + self.proj = nn.Linear(self.dh, embed_dims) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + """change the mode of model.""" + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + """forward function. + + Args: + x (tensor): input features with shape of (B, N, C) + """ + B, N, _ = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1) + + attn = ((q @ k.transpose(-2, -1)) * self.scale + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class Flat(nn.Module): + """Flat the input from (B, C, H, W) to (B, H*W, C).""" + + def __init__(self, ): + super().__init__() + + def forward(self, x: torch.Tensor): + x = x.flatten(2).transpose(1, 2) + return x + + +class LinearMlp(BaseModule): + """Mlp implemented with linear. + + The shape of input and output tensor are (B, N, C). + + Args: + in_features (int): Dimension of input features. + hidden_features (int): Dimension of hidden features. + out_features (int): Dimension of output features. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0.0. 
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = build_activation_layer(act_cfg) + self.drop1 = nn.Dropout(drop) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop2 = nn.Dropout(drop) + + def forward(self, x): + """ + Args: + x (torch.Tensor): input tensor with shape (B, N, C). + + Returns: + torch.Tensor: output tensor with shape (B, N, C). + """ + x = self.drop1(self.act(self.fc1(x))) + x = self.drop2(self.fc2(x)) + return x + + +class ConvMlp(BaseModule): + """Mlp implemented with 1*1 convolutions. + + Args: + in_features (int): Dimension of input features. + hidden_features (int): Dimension of hidden features. + out_features (int): Dimension of output features. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0.0. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1] + self.norm2 = build_norm_layer(norm_cfg, out_features)[1] + + self.drop = nn.Dropout(drop) + + def forward(self, x): + """ + Args: + x (torch.Tensor): input tensor with shape (B, C, H, W). + + Returns: + torch.Tensor: output tensor with shape (B, C, H, W). + """ + + x = self.act(self.norm1(self.fc1(x))) + x = self.drop(x) + x = self.norm2(self.fc2(x)) + x = self.drop(x) + return x + + +class Meta3D(BaseModule): + """Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape + (B, N, C).""" + + def __init__(self, + dim, + mlp_ratio=4., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + use_layer_scale=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = AttentionWithBias(dim) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = LinearMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
\ + else nn.Identity() + if use_layer_scale: + self.ls1 = LayerScale(dim) + self.ls2 = LayerScale(dim) + else: + self.ls1, self.ls2 = nn.Identity(), nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x)))) + x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class Meta4D(BaseModule): + """Meta Former block using 4 dimensions inputs, ``torch.Tensor`` with shape + (B, C, H, W).""" + + def __init__(self, + dim, + pool_size=3, + mlp_ratio=4., + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + use_layer_scale=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.token_mixer = Pooling(pool_size=pool_size) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ConvMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + if use_layer_scale: + self.ls1 = LayerScale(dim, data_format='channels_first') + self.ls2 = LayerScale(dim, data_format='channels_first') + else: + self.ls1, self.ls2 = nn.Identity(), nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.ls1(self.token_mixer(x))) + x = x + self.drop_path(self.ls2(self.mlp(x))) + return x + + +def basic_blocks(in_channels, + out_channels, + index, + layers, + pool_size=3, + mlp_ratio=4., + act_cfg=dict(type='GELU'), + drop_rate=.0, + drop_path_rate=0., + use_layer_scale=True, + vit_num=1, + has_downsamper=False): + """generate EfficientFormer blocks for a stage.""" + blocks = [] + if has_downsamper: + blocks.append( + ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=True, + norm_cfg=dict(type='BN'), + act_cfg=None)) + if index == 3 and vit_num == layers[index]: + blocks.append(Flat()) + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / ( + sum(layers) - 1) + if index == 3 and layers[index] - block_idx <= vit_num: + blocks.append( + Meta3D( + out_channels, + mlp_ratio=mlp_ratio, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale, + )) + else: + blocks.append( + Meta4D( + out_channels, + pool_size=pool_size, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale)) + if index == 3 and layers[index] - block_idx - 1 == vit_num: + blocks.append(Flat()) + blocks = nn.Sequential(*blocks) + return blocks + + +@MODELS.register_module() +class EfficientFormer(BaseBackbone): + """EfficientFormer. + + A PyTorch implementation of EfficientFormer introduced by: + `EfficientFormer: Vision Transformers at MobileNet Speed `_ + + Modified from the `official repo + `. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``EfficientFormer.arch_settings``. And if dict, + it should include the following 4 keys: + + - layers (list[int]): Number of blocks at each stage. + - embed_dims (list[int]): The number of channels at each stage. + - downsamples (list[int]): Has downsample or not in the four stages. + - vit_num (int): The num of vit blocks in the last stage. + + Defaults to 'l1'. + + in_channels (int): The num of input channels. Defaults to 3. + pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3. + mlp_ratios (int): The dimension ratio of multi-head attention mechanism + in ``Meta4D`` blocks. Defaults to 3. 
+ reshape_last_feat (bool): Whether to reshape the feature map from + (B, N, C) to (B, C, H, W) in the last stage, when the ``vit-num`` + in ``arch`` is not 0. Defaults to False. Usually set to True + in downstream tasks. + out_indices (Sequence[int]): Output from which stages. + Defaults to -1. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + use_layer_scale (bool): Whether to use use_layer_scale in MetaFormer + block. Defaults to True. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + + Example: + >>> from mmpretrain.models import EfficientFormer + >>> import torch + >>> inputs = torch.rand((1, 3, 224, 224)) + >>> # build EfficientFormer backbone for classification task + >>> model = EfficientFormer(arch="l1") + >>> model.eval() + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 448, 49) + >>> # build EfficientFormer backbone for downstream task + >>> model = EfficientFormer( + >>> arch="l3", + >>> out_indices=(0, 1, 2, 3), + >>> reshape_last_feat=True) + >>> model.eval() + >>> level_outputs = model(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 56, 56) + (1, 128, 28, 28) + (1, 320, 14, 14) + (1, 512, 7, 7) + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims: [x,x,x,x], embedding dims for the four stages + # --downsamples: [x,x,x,x], has downsample or not in the four stages + # --vit_num:(int), the num of vit blocks in the last stage + arch_settings = { + 'l1': { + 'layers': [3, 2, 6, 4], + 'embed_dims': [48, 96, 224, 448], + 'downsamples': [False, True, True, True], + 'vit_num': 1, + }, + 'l3': { + 'layers': [4, 4, 12, 6], + 'embed_dims': [64, 128, 320, 512], + 'downsamples': [False, True, True, True], + 'vit_num': 4, + }, + 'l7': { + 'layers': [6, 6, 18, 8], + 'embed_dims': [96, 192, 384, 768], + 'downsamples': [False, True, True, True], + 'vit_num': 8, + }, + } + + def __init__(self, + arch='l1', + in_channels=3, + pool_size=3, + mlp_ratios=4, + reshape_last_feat=False, + out_indices=-1, + frozen_stages=-1, + act_cfg=dict(type='GELU'), + drop_rate=0., + drop_path_rate=0., + use_layer_scale=True, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.num_extra_tokens = 0 # no cls_token, no dist_token + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + default_keys = set(self.arch_settings['l1'].keys()) + assert set(arch.keys()) == default_keys, \ + f'The arch dict must have {default_keys}, ' \ + f'but got {list(arch.keys())}.' + + self.layers = arch['layers'] + self.embed_dims = arch['embed_dims'] + self.downsamples = arch['downsamples'] + assert isinstance(self.layers, list) and isinstance( + self.embed_dims, list) and isinstance(self.downsamples, list) + assert len(self.layers) == len(self.embed_dims) == len( + self.downsamples) + + self.vit_num = arch['vit_num'] + self.reshape_last_feat = reshape_last_feat + + assert self.vit_num >= 0, "'vit_num' must be an integer " \ + 'greater than or equal to 0.' 
+ assert self.vit_num <= self.layers[-1], ( + "'vit_num' must be an integer smaller than layer number") + + self._make_stem(in_channels, self.embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(self.layers)): + if i != 0: + in_channels = self.embed_dims[i - 1] + else: + in_channels = self.embed_dims[i] + out_channels = self.embed_dims[i] + stage = basic_blocks( + in_channels, + out_channels, + i, + self.layers, + pool_size=pool_size, + mlp_ratio=mlp_ratios, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + vit_num=self.vit_num, + use_layer_scale=use_layer_scale, + has_downsamper=self.downsamples[i]) + network.append(stage) + + self.network = ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + + self.out_indices = out_indices + for i_layer in self.out_indices: + if not self.reshape_last_feat and \ + i_layer == 3 and self.vit_num > 0: + layer = build_norm_layer( + dict(type='LN'), self.embed_dims[i_layer])[1] + else: + # use GN with 1 group as channel-first LN2D + layer = build_norm_layer( + dict(type='GN', num_groups=1), self.embed_dims[i_layer])[1] + + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + + def _make_stem(self, in_channels: int, stem_channels: int): + """make 2-ConvBNReLu stem layer.""" + self.patch_embed = Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=2, + padding=1, + bias=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + inplace=True)) + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + if idx == len(self.network) - 1: + N, _, H, W = x.shape + if self.downsamples[idx]: + H, W = H // 2, W // 2 + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + + if idx == len(self.network) - 1 and x.dim() == 3: + # when ``vit-num`` > 0 and in the last stage, + # if `self.reshape_last_feat`` is True, reshape the + # features to `BCHW` format before the final normalization. + # if `self.reshape_last_feat`` is False, do + # normalization directly and permute the features to `BCN`. + if self.reshape_last_feat: + x = x.permute((0, 2, 1)).reshape(N, -1, H, W) + x_out = norm_layer(x) + else: + x_out = norm_layer(x).permute((0, 2, 1)) + else: + x_out = norm_layer(x) + + outs.append(x_out.contiguous()) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.patch_embed(x) + # through stages + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + # Include both block and downsample layer. 
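+            # The norm layer registered for a frozen output stage is frozen
+            # as well; it lives outside ``self.network``, so its affine
+            # parameters would otherwise keep receiving gradients.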
+ module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientFormer, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/efficientnet.py b/mmpretrain/models/backbones/efficientnet.py new file mode 100644 index 0000000..9ec7ee8 --- /dev/null +++ b/mmpretrain/models/backbones/efficientnet.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import ConvModule, DropPath +from mmengine.model import BaseModule, Sequential + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.models.utils import InvertedResidual, SELayer, make_divisible +from mmpretrain.registry import MODELS + + +class EdgeResidual(BaseModule): + """Edge Residual Block. + + Args: + in_channels (int): The input channels of this module. + out_channels (int): The output channels of this module. + mid_channels (int): The input channels of the second convolution. + kernel_size (int): The kernel size of the first convolution. + Defaults to 3. + stride (int): The stride of the first convolution. Defaults to 1. + se_cfg (dict, optional): Config dict for se layer. Defaults to None, + which means no se layer. + with_residual (bool): Use residual connection. Defaults to True. + conv_cfg (dict, optional): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict | list[dict], optional): Initialization config dict. 
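+
+    Example (illustrative sketch; with ``stride=1`` and matching in/out
+    channels the residual branch is used, so the shape is preserved):
+        >>> import torch
+        >>> block = EdgeResidual(
+        ...     in_channels=32, out_channels=32, mid_channels=96)
+        >>> out = block(torch.rand(1, 32, 56, 56))
+        >>> print(tuple(out.shape))
+        (1, 32, 56, 56)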
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_residual=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None): + super(EdgeResidual, self).__init__(init_cfg=init_cfg) + assert stride in [1, 2] + self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_residual = ( + stride == 1 and in_channels == out_channels and with_residual) + + if self.with_se: + assert isinstance(se_cfg, dict) + + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.conv2 = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + out = self.conv1(out) + + if self.with_se: + out = self.se(out) + + out = self.conv2(out) + + if self.with_residual: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +def model_scaling(layer_setting, arch_setting): + """Scaling operation to the layer's parameters according to the + arch_setting.""" + # scale width + new_layer_setting = copy.deepcopy(layer_setting) + for layer_cfg in new_layer_setting: + for block_cfg in layer_cfg: + block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8) + + # scale depth + split_layer_setting = [new_layer_setting[0]] + for layer_cfg in new_layer_setting[1:-1]: + tmp_index = [0] + for i in range(len(layer_cfg) - 1): + if layer_cfg[i + 1][1] != layer_cfg[i][1]: + tmp_index.append(i + 1) + tmp_index.append(len(layer_cfg)) + for i in range(len(tmp_index) - 1): + split_layer_setting.append(layer_cfg[tmp_index[i]:tmp_index[i + + 1]]) + split_layer_setting.append(new_layer_setting[-1]) + + num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]] + new_layers = [ + int(math.ceil(arch_setting[1] * num)) for num in num_of_layers + ] + + merge_layer_setting = [split_layer_setting[0]] + for i, layer_cfg in enumerate(split_layer_setting[1:-1]): + if new_layers[i] <= num_of_layers[i]: + tmp_layer_cfg = layer_cfg[:new_layers[i]] + else: + tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * ( + new_layers[i] - num_of_layers[i]) + if tmp_layer_cfg[0][3] == 1 and i != 0: + merge_layer_setting[-1] += tmp_layer_cfg.copy() + else: + merge_layer_setting.append(tmp_layer_cfg.copy()) + merge_layer_setting.append(split_layer_setting[-1]) + + return merge_layer_setting + + +@MODELS.register_module() +class EfficientNet(BaseBackbone): + """EfficientNet backbone. + + Args: + arch (str): Architecture of efficientnet. Defaults to b0. + out_indices (Sequence[int]): Output from which stages. + Defaults to (6, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict): Config dict for activation layer. 
+ Defaults to dict(type='Swish'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + """ + + # Parameters to build layers. + # 'b' represents the architecture of normal EfficientNet family includes + # 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8'. + # 'e' represents the architecture of EfficientNet-EdgeTPU including 'es', + # 'em', 'el'. + # 6 parameters are needed to construct a layer, From left to right: + # - kernel_size: The kernel size of the block + # - out_channel: The number of out_channels of the block + # - se_ratio: The sequeeze ratio of SELayer. + # - stride: The stride of the block + # - expand_ratio: The expand_ratio of the mid_channels + # - block_type: -1: Not a block, 0: InvertedResidual, 1: EdgeResidual + layer_settings = { + 'b': [[[3, 32, 0, 2, 0, -1]], + [[3, 16, 4, 1, 1, 0]], + [[3, 24, 4, 2, 6, 0], + [3, 24, 4, 1, 6, 0]], + [[5, 40, 4, 2, 6, 0], + [5, 40, 4, 1, 6, 0]], + [[3, 80, 4, 2, 6, 0], + [3, 80, 4, 1, 6, 0], + [3, 80, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0], + [5, 112, 4, 1, 6, 0]], + [[5, 192, 4, 2, 6, 0], + [5, 192, 4, 1, 6, 0], + [5, 192, 4, 1, 6, 0], + [5, 192, 4, 1, 6, 0], + [3, 320, 4, 1, 6, 0]], + [[1, 1280, 0, 1, 0, -1]] + ], + 'e': [[[3, 32, 0, 2, 0, -1]], + [[3, 24, 0, 1, 3, 1]], + [[3, 32, 0, 2, 8, 1], + [3, 32, 0, 1, 8, 1]], + [[3, 48, 0, 2, 8, 1], + [3, 48, 0, 1, 8, 1], + [3, 48, 0, 1, 8, 1], + [3, 48, 0, 1, 8, 1]], + [[5, 96, 0, 2, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 96, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0], + [5, 144, 0, 1, 8, 0]], + [[5, 192, 0, 2, 8, 0], + [5, 192, 0, 1, 8, 0]], + [[1, 1280, 0, 1, 0, -1]] + ] + } # yapf: disable + + # Parameters to build different kinds of architecture. + # From left to right: scaling factor for width, scaling factor for depth, + # resolution. + arch_settings = { + 'b0': (1.0, 1.0, 224), + 'b1': (1.0, 1.1, 240), + 'b2': (1.1, 1.2, 260), + 'b3': (1.2, 1.4, 300), + 'b4': (1.4, 1.8, 380), + 'b5': (1.6, 2.2, 456), + 'b6': (1.8, 2.6, 528), + 'b7': (2.0, 3.1, 600), + 'b8': (2.2, 3.6, 672), + 'l2': (4.3, 5.3, 800), + 'es': (1.0, 1.0, 224), + 'em': (1.0, 1.1, 240), + 'el': (1.2, 1.4, 300) + } + + def __init__(self, + arch='b0', + drop_path_rate=0., + out_indices=(6, ), + frozen_stages=0, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='Swish'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + layer=['_BatchNorm', 'GroupNorm'], + val=1) + ]): + super(EfficientNet, self).__init__(init_cfg) + assert arch in self.arch_settings, \ + f'"{arch}" is not one of the arch_settings ' \ + f'({", ".join(self.arch_settings.keys())})' + self.arch_setting = self.arch_settings[arch] + # layer_settings of arch='l2' is 'b' + self.layer_setting = self.layer_settings['b' if arch == + 'l2' else arch[:1]] + for index in out_indices: + if index not in range(0, len(self.layer_setting)): + raise ValueError('the item in out_indices must in ' + f'range(0, {len(self.layer_setting)}). 
' + f'But received {index}') + + if frozen_stages not in range(len(self.layer_setting) + 1): + raise ValueError('frozen_stages must be in range(0, ' + f'{len(self.layer_setting) + 1}). ' + f'But received {frozen_stages}') + self.drop_path_rate = drop_path_rate + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layer_setting = model_scaling(self.layer_setting, + self.arch_setting) + block_cfg_0 = self.layer_setting[0][0] + block_cfg_last = self.layer_setting[-1][0] + self.in_channels = make_divisible(block_cfg_0[1], 8) + self.out_channels = block_cfg_last[1] + self.layers = nn.ModuleList() + self.layers.append( + ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=block_cfg_0[0], + stride=block_cfg_0[3], + padding=block_cfg_0[0] // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.make_layer() + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=block_cfg_last[0], + stride=block_cfg_last[3], + padding=block_cfg_last[0] // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def make_layer(self): + # Without the first and the final conv block. + layer_setting = self.layer_setting[1:-1] + + total_num_blocks = sum([len(x) for x in layer_setting]) + block_idx = 0 + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, total_num_blocks) + ] # stochastic depth decay rule + + for layer_cfg in layer_setting: + layer = [] + for i, block_cfg in enumerate(layer_cfg): + (kernel_size, out_channels, se_ratio, stride, expand_ratio, + block_type) = block_cfg + + mid_channels = int(self.in_channels * expand_ratio) + out_channels = make_divisible(out_channels, 8) + if se_ratio <= 0: + se_cfg = None + else: + se_cfg = dict( + channels=mid_channels, + ratio=expand_ratio * se_ratio, + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + if block_type == 1: # edge tpu + if i > 0 and expand_ratio == 3: + with_residual = False + expand_ratio = 4 + else: + with_residual = True + mid_channels = int(self.in_channels * expand_ratio) + if se_cfg is not None: + se_cfg = dict( + channels=mid_channels, + ratio=se_ratio * expand_ratio, + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + block = partial(EdgeResidual, with_residual=with_residual) + else: + block = InvertedResidual + layer.append( + block( + in_channels=self.in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + drop_path_rate=dpr[block_idx], + with_cp=self.with_cp)) + self.in_channels = out_channels + block_idx += 1 + self.layers.append(Sequential(*layer)) + + def forward(self, x): + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpretrain/models/backbones/efficientnet_v2.py 
b/mmpretrain/models/backbones/efficientnet_v2.py new file mode 100644 index 0000000..fec002a --- /dev/null +++ b/mmpretrain/models/backbones/efficientnet_v2.py @@ -0,0 +1,343 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import ConvModule, DropPath +from mmengine.model import Sequential +from torch import Tensor + +from mmpretrain.registry import MODELS +from ..utils import InvertedResidual as MBConv +from .base_backbone import BaseBackbone +from .efficientnet import EdgeResidual as FusedMBConv + + +class EnhancedConvModule(ConvModule): + """ConvModule with short-cut and droppath. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + has_skip (bool): Whether there is short-cut. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + def __init__(self, *args, has_skip=False, drop_path_rate=0, **kwargs): + super().__init__(*args, **kwargs) + self.has_skip = has_skip + if self.has_skip and (self.in_channels != self.out_channels + or self.stride != (1, 1)): + raise ValueError('the stride must be 1 and the `in_channels` and' + ' `out_channels` must be the same , when ' + '`has_skip` is True in `EnhancedConvModule` .') + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate else nn.Identity() + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + short_cut = x + x = super().forward(x, **kwargs) + if self.has_skip: + x = self.drop_path(x) + short_cut + return x + + +@MODELS.register_module() +class EfficientNetV2(BaseBackbone): + """EfficientNetV2 backbone. 
+ + A PyTorch implementation of EfficientNetV2 introduced by: + `EfficientNetV2: Smaller Models and Faster Training + `_ + + Args: + arch (str): Architecture of efficientnetv2. Defaults to s. + in_channels (int): Number of input image channels. Defaults to 3. + drop_path_rate (float): The ratio of the stochastic depth. + Defaults to 0.0. + out_indices (Sequence[int]): Output from which stages. + Defaults to (-1, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Defaults to dict(type='Swish'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + """ + + # Parameters to build layers. From left to right: + # - repeat (int): The repeat number of the block in the layer + # - kernel_size (int): The kernel size of the layer + # - stride (int): The stride of the first block of the layer + # - expand_ratio (int, float): The expand_ratio of the mid_channels + # - in_channel (int): The number of in_channels of the layer + # - out_channel (int): The number of out_channels of the layer + # - se_ratio (float): The sequeeze ratio of SELayer. + # - block_type (int): -2: ConvModule, -1: EnhancedConvModule, + # 0: FusedMBConv, 1: MBConv + arch_settings = { + **dict.fromkeys(['small', 's'], [[2, 3, 1, 1, 24, 24, 0.0, -1], + [4, 3, 2, 4, 24, 48, 0.0, 0], + [4, 3, 2, 4, 48, 64, 0.0, 0], + [6, 3, 2, 4, 64, 128, 0.25, 1], + [9, 3, 1, 6, 128, 160, 0.25, 1], + [15, 3, 2, 6, 160, 256, 0.25, 1], + [1, 1, 1, 1, 256, 1280, 0.0, -2]]), + **dict.fromkeys(['m', 'medium'], [[3, 3, 1, 1, 24, 24, 0.0, -1], + [5, 3, 2, 4, 24, 48, 0.0, 0], + [5, 3, 2, 4, 48, 80, 0.0, 0], + [7, 3, 2, 4, 80, 160, 0.25, 1], + [14, 3, 1, 6, 160, 176, 0.25, 1], + [18, 3, 2, 6, 176, 304, 0.25, 1], + [5, 3, 1, 6, 304, 512, 0.25, 1], + [1, 1, 1, 1, 512, 1280, 0.0, -2]]), + **dict.fromkeys(['l', 'large'], [[4, 3, 1, 1, 32, 32, 0.0, -1], + [7, 3, 2, 4, 32, 64, 0.0, 0], + [7, 3, 2, 4, 64, 96, 0.0, 0], + [10, 3, 2, 4, 96, 192, 0.25, 1], + [19, 3, 1, 6, 192, 224, 0.25, 1], + [25, 3, 2, 6, 224, 384, 0.25, 1], + [7, 3, 1, 6, 384, 640, 0.25, 1], + [1, 1, 1, 1, 640, 1280, 0.0, -2]]), + **dict.fromkeys(['xl'], [[4, 3, 1, 1, 32, 32, 0.0, -1], + [8, 3, 2, 4, 32, 64, 0.0, 0], + [8, 3, 2, 4, 64, 96, 0.0, 0], + [16, 3, 2, 4, 96, 192, 0.25, 1], + [24, 3, 1, 6, 192, 256, 0.25, 1], + [32, 3, 2, 6, 256, 512, 0.25, 1], + [8, 3, 1, 6, 512, 640, 0.25, 1], + [1, 1, 1, 1, 640, 1280, 0.0, -2]]), + **dict.fromkeys(['b0'], [[1, 3, 1, 1, 32, 16, 0.0, -1], + [2, 3, 2, 4, 16, 32, 0.0, 0], + [2, 3, 2, 4, 32, 48, 0.0, 0], + [3, 3, 2, 4, 48, 96, 0.25, 1], + [5, 3, 1, 6, 96, 112, 0.25, 1], + [8, 3, 2, 6, 112, 192, 0.25, 1], + [1, 1, 1, 1, 192, 1280, 0.0, -2]]), + **dict.fromkeys(['b1'], [[2, 3, 1, 1, 32, 16, 0.0, -1], + [3, 3, 2, 4, 16, 32, 0.0, 0], + [3, 3, 2, 4, 32, 48, 0.0, 0], + [4, 3, 2, 4, 48, 96, 0.25, 1], + [6, 3, 1, 6, 96, 112, 0.25, 1], + [9, 3, 2, 6, 112, 192, 0.25, 1], + [1, 1, 1, 1, 192, 1280, 0.0, -2]]), + **dict.fromkeys(['b2'], [[2, 3, 1, 1, 32, 16, 0.0, -1], + [3, 3, 2, 4, 16, 32, 0.0, 0], + [3, 
3, 2, 4, 32, 56, 0.0, 0], + [4, 3, 2, 4, 56, 104, 0.25, 1], + [6, 3, 1, 6, 104, 120, 0.25, 1], + [10, 3, 2, 6, 120, 208, 0.25, 1], + [1, 1, 1, 1, 208, 1408, 0.0, -2]]), + **dict.fromkeys(['b3'], [[2, 3, 1, 1, 40, 16, 0.0, -1], + [3, 3, 2, 4, 16, 40, 0.0, 0], + [3, 3, 2, 4, 40, 56, 0.0, 0], + [5, 3, 2, 4, 56, 112, 0.25, 1], + [7, 3, 1, 6, 112, 136, 0.25, 1], + [12, 3, 2, 6, 136, 232, 0.25, 1], + [1, 1, 1, 1, 232, 1536, 0.0, -2]]) + } + + def __init__(self, + arch: str = 's', + in_channels: int = 3, + drop_path_rate: float = 0., + out_indices: Sequence[int] = (-1, ), + frozen_stages: int = 0, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.1), + act_cfg=dict(type='Swish'), + norm_eval: bool = False, + with_cp: bool = False, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + layer=['_BatchNorm', 'GroupNorm'], + val=1) + ]): + super(EfficientNetV2, self).__init__(init_cfg) + assert arch in self.arch_settings, \ + f'"{arch}" is not one of the arch_settings ' \ + f'({", ".join(self.arch_settings.keys())})' + self.arch = self.arch_settings[arch] + if frozen_stages not in range(len(self.arch) + 1): + raise ValueError('frozen_stages must be in range(0, ' + f'{len(self.arch)}), but get {frozen_stages}') + self.drop_path_rate = drop_path_rate + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layers = nn.ModuleList() + assert self.arch[-1][-1] == -2, \ + f'the last block_type of `arch_setting` must be -2 ,' \ + f'but get `{self.arch[-1][-1]}`' + self.in_channels = in_channels + self.out_channels = self.arch[-1][5] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.make_layers() + + # there len(slef.arch) + 2 layers in the backbone + # including: the first + len(self.arch) layers + the last + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = len(self.layers) + index + assert 0 <= out_indices[i] <= len(self.layers), \ + f'Invalid out_indices {index}.' 
+ self.out_indices = out_indices + + def make_layers(self, ): + # make the first layer + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=self.arch[0][4], + kernel_size=3, + stride=2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + in_channels = self.arch[0][4] + layer_setting = self.arch[:-1] + + total_num_blocks = sum([x[0] for x in layer_setting]) + block_idx = 0 + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, total_num_blocks) + ] # stochastic depth decay rule + + for layer_cfg in layer_setting: + layer = [] + (repeat, kernel_size, stride, expand_ratio, _, out_channels, + se_ratio, block_type) = layer_cfg + for i in range(repeat): + stride = stride if i == 0 else 1 + if block_type == -1: + has_skip = stride == 1 and in_channels == out_channels + droppath_rate = dpr[block_idx] if has_skip else 0.0 + layer.append( + EnhancedConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + has_skip=has_skip, + drop_path_rate=droppath_rate, + stride=stride, + padding=1, + conv_cfg=None, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + in_channels = out_channels + else: + mid_channels = int(in_channels * expand_ratio) + se_cfg = None + if block_type != 0 and se_ratio > 0: + se_cfg = dict( + channels=mid_channels, + ratio=expand_ratio * (1.0 / se_ratio), + divisor=1, + act_cfg=(self.act_cfg, dict(type='Sigmoid'))) + block = FusedMBConv if block_type == 0 else MBConv + conv_cfg = self.conv_cfg if stride == 2 else None + layer.append( + block( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + drop_path_rate=dpr[block_idx], + with_cp=self.with_cp)) + in_channels = out_channels + block_idx += 1 + self.layers.append(Sequential(*layer)) + + # make the last layer + self.layers.append( + ConvModule( + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=self.arch[-1][1], + stride=self.arch[-1][2], + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, x: Tensor) -> Tuple[Tensor]: + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(EfficientNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpretrain/models/backbones/hivit.py b/mmpretrain/models/backbones/hivit.py new file mode 100644 index 0000000..981cbf8 --- /dev/null +++ b/mmpretrain/models/backbones/hivit.py @@ -0,0 +1,656 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer, to_2tuple +from .base_backbone import BaseBackbone + + +class Mlp(nn.Module): + """MLP block. + + Args: + in_features (int): Number of input dims. + hidden_features (int): Number of hidden dims. + out_feature (int): Number of out dims. + act_layer: MLP activation layer. + drop (float): MLP dropout rate. 
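+
+    Example (illustrative only; the block maps (B, N, C) back to (B, N, C)):
+        >>> import torch
+        >>> mlp = Mlp(in_features=128, hidden_features=512)
+        >>> out = mlp(torch.rand(1, 196, 128))
+        >>> print(tuple(out.shape))
+        (1, 196, 128)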
+ """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + """Attention. + + Args: + input size (int): Input size. + dim (int): Number of input dims. + num_heads (int): Number of attention heads. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + proj_drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + rpe (bool): If True, add relative position embedding to + the patch embedding. + """ + + def __init__(self, + input_size, + dim, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + rpe=True): + super().__init__() + self.input_size = input_size + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * input_size - 1) * + (2 * input_size - 1), num_heads)) if rpe else None + if rpe: + coords_h = torch.arange(input_size) + coords_w = torch.arange(input_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += input_size - 1 + relative_coords[:, :, 1] += input_size - 1 + relative_coords[:, :, 0] *= 2 * input_size - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer('relative_position_index', + relative_position_index) + + trunc_normal_(self.relative_position_bias_table, std=.02) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, rpe_index=None, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if rpe_index is not None: + rpe_index = self.relative_position_index.view(-1) + S = int(math.sqrt(rpe_index.size(-1))) + relative_position_bias = self.relative_position_bias_table[ + rpe_index].view(-1, S, S, self.num_heads) + relative_position_bias = relative_position_bias.permute( + 0, 3, 1, 2).contiguous() + attn = attn + relative_position_bias + if mask is not None: + mask = mask.bool() + attn = attn.masked_fill(~mask[:, None, None, :], float('-inf')) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class BlockWithRPE(nn.Module): + """HiViT block. + + Args: + input_size (int): Input size. + dim (int): Number of input dims. 
+ num_heads (int): Number of attention heads. + mlp_ratio (int): Ratio of MLP hidden dim to embedding dim. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. + drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + rpe (bool): If True, add relative position embedding to + the patch embedding. + layer_scale_init_value (float): Layer-scale init values. Defaults to 0. + act_layer: MLP activation layer. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + """ + + def __init__(self, + input_size, + dim, + num_heads=0., + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + rpe=True, + layer_scale_init_value=0.0, + act_layer=nn.GELU, + norm_cfg=dict(type='LN')): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + + with_attn = num_heads > 0. + + self.norm1 = build_norm_layer(norm_cfg, dim) if with_attn else None + self.attn = Attention( + input_size, + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + rpe=rpe, + ) if with_attn else None + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = build_norm_layer(norm_cfg, dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if layer_scale_init_value > 0: + self.gamma_1 = nn.Parameter( + layer_scale_init_value * torch.ones( + (dim)), requires_grad=True) if with_attn else None + self.gamma_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rpe_index=None, mask=None): + if self.attn is not None: + if self.gamma_1 is not None: + x = x + self.drop_path( + self.gamma_1 * self.attn(self.norm1(x), rpe_index, mask)) + else: + x = x + self.drop_path( + self.attn(self.norm1(x), rpe_index, mask)) + if self.gamma_2 is not None: + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """PatchEmbed for HiViT. + + Args: + img_size (int): Input image size. + patch_size (int): Patch size. Defaults to 16. + inner_patches (int): Inner patch. Defaults to 4. + in_chans (int): Number of image input channels. + embed_dim (int): Transformer embedding dimension. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + kernel_size (int): Kernel size. + pad_size (int): Pad size. 
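+
+    Example:
+        A minimal shape check (illustrative only, using the defaults):
+
+        >>> import torch
+        >>> embed = PatchEmbed()
+        >>> embed(torch.rand(1, 3, 224, 224)).shape
+        torch.Size([1, 196, 4, 4, 128])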
+ """ + + def __init__(self, + img_size=224, + patch_size=16, + inner_patches=4, + in_chans=3, + embed_dim=128, + norm_cfg=None, + kernel_size=None, + pad_size=None): + super().__init__() + img_size = to_2tuple(img_size) if not isinstance(img_size, + tuple) else img_size + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.inner_patches = inner_patches + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + conv_size = [size // inner_patches for size in patch_size] + kernel_size = kernel_size or conv_size + pad_size = pad_size or 0 + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=conv_size, + padding=pad_size) + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + patches_resolution = (H // self.patch_size[0], W // self.patch_size[1]) + num_patches = patches_resolution[0] * patches_resolution[1] + x = self.proj(x).view( + B, + -1, + patches_resolution[0], + self.inner_patches, + patches_resolution[1], + self.inner_patches, + ).permute(0, 2, 4, 3, 5, 1).reshape(B, num_patches, self.inner_patches, + self.inner_patches, -1) + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchMerge(nn.Module): + """PatchMerge for HiViT. + + Args: + dim (int): Number of input channels. + norm_cfg (dict): Config dict for normalization layer. + """ + + def __init__(self, dim, norm_cfg): + super().__init__() + self.norm = build_norm_layer(norm_cfg, dim * 4) + self.reduction = nn.Linear(dim * 4, dim * 2, bias=False) + + def forward(self, x, *args, **kwargs): + is_main_stage = len(x.shape) == 3 + if is_main_stage: + B, N, C = x.shape + S = int(math.sqrt(N)) + x = x.reshape(B, S // 2, 2, S // 2, 2, C) \ + .permute(0, 1, 3, 2, 4, 5) \ + .reshape(B, -1, 2, 2, C) + x0 = x[..., 0::2, 0::2, :] + x1 = x[..., 1::2, 0::2, :] + x2 = x[..., 0::2, 1::2, :] + x3 = x[..., 1::2, 1::2, :] + + x = torch.cat([x0, x1, x2, x3], dim=-1) + x = self.norm(x) + x = self.reduction(x) + + if is_main_stage: + x = x[:, :, 0, 0, :] + return x + + +@MODELS.register_module() +class HiViT(BaseBackbone): + """HiViT. + + A PyTorch implement of: `HiViT: A Simple and More Efficient Design + of Hierarchical Vision Transformer `_. + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', and'base'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (int): The number of heads in attention + modules of each stage. + + Defaults to 'tiny'. + img_size (int): Input image size. + patch_size (int): Patch size. Defaults to 16. + inner_patches (int): Inner patch. Defaults to 4. + in_chans (int): Number of image input channels. + embed_dim (int): Transformer embedding dimension. + depths (list[int]): Number of successive HiViT blocks. + num_heads (int): Number of attention heads. + stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim + in the first two stages. + mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in + the last stage. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. 
+ drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): If True, add absolute position embedding to + the patch embedding. + rpe (bool): If True, add relative position embedding to + the patch embedding. + patch_norm (bool): If True, use norm_cfg for normalization layer. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + kernel_size (int): Kernel size. + pad_size (int): Pad size. + layer_scale_init_value (float): Layer-scale init values. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 384, + 'depths': [1, 1, 10], + 'num_heads': 6}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 384, + 'depths': [2, 2, 20], + 'num_heads': 6}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 512, + 'depths': [2, 2, 24], + 'num_heads': 8}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 768, + 'depths': [2, 2, 40], + 'num_heads': 12}), + } # yapf: disable + + num_extra_tokens = 0 + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + inner_patches=4, + in_chans=3, + stem_mlp_ratio=3., + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.0, + norm_cfg=dict(type='LN'), + out_indices=[23], + ape=True, + rpe=False, + patch_norm=True, + frozen_stages=-1, + kernel_size=None, + pad_size=None, + layer_scale_init_value=0.0, + init_cfg=None): + super(HiViT, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + + self.num_stages = len(self.depths) + self.ape = ape + self.rpe = rpe + self.patch_size = patch_size + self.num_features = self.embed_dims + self.mlp_ratio = mlp_ratio + self.num_main_blocks = self.depths[-1] + self.out_indices = out_indices + self.out_indices[-1] = self.depths[-1] - 1 + + img_size = to_2tuple(img_size) if not isinstance(img_size, + tuple) else img_size + + embed_dim = self.embed_dims // 2**(self.num_stages - 1) + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + in_chans=in_chans, + embed_dim=embed_dim, + norm_cfg=norm_cfg if patch_norm else None, + kernel_size=kernel_size, + pad_size=pad_size) + num_patches = self.patch_embed.num_patches + Hp, Wp = self.patch_embed.patches_resolution + + if rpe: + assert Hp == Wp, 'If you use relative position, make sure H == W ' + 'of input size' + + # absolute position embedding + if ape: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.num_features)) + trunc_normal_(self.pos_embed, std=.02) + if rpe: + # get 
pair-wise relative position index for each token inside the + # window + coords_h = torch.arange(Hp) + coords_w = torch.arange(Wp) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += Hp - 1 + relative_coords[:, :, 1] += Wp - 1 + relative_coords[:, :, 0] *= 2 * Wp - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer('relative_position_index', + relative_position_index) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = iter( + x.item() + for x in torch.linspace(0, drop_path_rate, + sum(self.depths) + sum(self.depths[:-1]))) + + # build blocks + self.blocks = nn.ModuleList() + for stage_i, stage_depth in enumerate(self.depths): + is_main_stage = embed_dim == self.num_features + nhead = self.num_heads if is_main_stage else 0 + ratio = mlp_ratio if is_main_stage else stem_mlp_ratio + # every block not in main stage includes two mlp blocks + stage_depth = stage_depth if is_main_stage else stage_depth * 2 + for _ in range(stage_depth): + self.blocks.append( + BlockWithRPE( + Hp, + embed_dim, + nhead, + ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=next(dpr), + rpe=rpe, + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + )) + if stage_i + 1 < self.num_stages: + self.blocks.append(PatchMerge(embed_dim, norm_cfg)) + embed_dim *= 2 + + self.frozen_stages = frozen_stages + if self.frozen_stages > 0: + self._freeze_stages() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, h, w): + npatch = x.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + patch_pos_embed = self.pos_embed + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), + dim).permute(0, 3, 1, 2), + scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(h0) == patch_pos_embed.shape[-2] and int( + w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, x): + B, C, H, W = x.shape + Hp, Wp = H // self.patch_size, W // self.patch_size + + x = self.patch_embed(x) + + outs = [] + for i, blk in enumerate(self.blocks[:-self.num_main_blocks]): + x = blk(x) + if i in self.out_indices: + x = x.reshape(B, Hp, Wp, *x.shape[-3:]).permute( + 0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * x.shape[-3], + Wp * x.shape[-2]).contiguous() + outs.append(x) + + x = x[..., 0, 0, :] + if self.ape: + x = x + self.interpolate_pos_encoding(x, H, W) + x = self.pos_drop(x) + + rpe_index = True if self.rpe else None + + for i, blk in enumerate(self.blocks[-self.num_main_blocks:]): + x = blk(x, rpe_index) + if i in self.out_indices: + x = x.transpose(1, 2).view(B, -1, 
Hp, Wp).contiguous() + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.pos_drop.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + for param in self.fc_norm.parameters(): + param.requires_grad = False + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + self.num_layers = len(self.blocks) + num_layers = self.num_layers + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name in 'pos_embed': + layer_depth = 0 + elif param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + layer_depth = layer_id + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/hornet.py b/mmpretrain/models/backbones/hornet.py new file mode 100644 index 0000000..460f2dc --- /dev/null +++ b/mmpretrain/models/backbones/hornet.py @@ -0,0 +1,500 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from official impl at https://github.com/raoyongming/HorNet. +try: + import torch.fft + fft = True +except ImportError: + fft = None + +import copy +from functools import partial +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from mmcv.cnn.bricks import DropPath + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import LayerScale + + +def get_dwconv(dim, kernel_size, bias=True): + """build a pepth-wise convolution.""" + return nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + bias=bias, + groups=dim) + + +class HorNetLayerNorm(nn.Module): + """An implementation of LayerNorm of HorNet. + + The differences between HorNetLayerNorm & torch LayerNorm: + 1. Supports two data formats channels_last or channels_first. + Args: + normalized_shape (int or list or torch.Size): input shape from an + expected input of size. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-5. + data_format (str): The ordering of the dimensions in the inputs. + channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with + shape (batch_size, channels, height, width). + Defaults to 'channels_last'. 
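+
+    Example:
+        An illustrative check of the channels_first mode:
+
+        >>> import torch
+        >>> norm = HorNetLayerNorm(64, data_format='channels_first')
+        >>> norm(torch.rand(2, 64, 14, 14)).shape
+        torch.Size([2, 64, 14, 14])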
+ """ + + def __init__(self, + normalized_shape, + eps=1e-6, + data_format='channels_last'): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ['channels_last', 'channels_first']: + raise ValueError( + 'data_format must be channels_last or channels_first') + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == 'channels_last': + return F.layer_norm(x, self.normalized_shape, self.weight, + self.bias, self.eps) + elif self.data_format == 'channels_first': + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GlobalLocalFilter(nn.Module): + """A GlobalLocalFilter of HorNet. + + Args: + dim (int): Number of input channels. + h (int): Height of complex_weight. + Defaults to 14. + w (int): Width of complex_weight. + Defaults to 8. + """ + + def __init__(self, dim, h=14, w=8): + super().__init__() + self.dw = nn.Conv2d( + dim // 2, + dim // 2, + kernel_size=3, + padding=1, + bias=False, + groups=dim // 2) + self.complex_weight = nn.Parameter( + torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02) + self.pre_norm = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + self.post_norm = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + + def forward(self, x): + x = self.pre_norm(x) + x1, x2 = torch.chunk(x, 2, dim=1) + x1 = self.dw(x1) + + x2 = x2.to(torch.float32) + B, C, a, b = x2.shape + x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho') + + weight = self.complex_weight + if not weight.shape[1:3] == x2.shape[2:4]: + weight = F.interpolate( + weight.permute(3, 0, 1, 2), + size=x2.shape[2:4], + mode='bilinear', + align_corners=True).permute(1, 2, 3, 0) + + weight = torch.view_as_complex(weight.contiguous()) + + x2 = x2 * weight + x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho') + + x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], + dim=2).reshape(B, 2 * C, a, b) + x = self.post_norm(x) + return x + + +class gnConv(nn.Module): + """A gnConv of HorNet. + + Args: + dim (int): Number of input channels. + order (int): Order of gnConv. + Defaults to 5. + dw_cfg (dict): The Config for dw conv. + Defaults to ``dict(type='DW', kernel_size=7)``. + scale (float): Scaling parameter of gflayer outputs. + Defaults to 1.0. 
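+
+    Example:
+        Illustrative usage with the default depth-wise convolution config:
+
+        >>> import torch
+        >>> conv = gnConv(64, order=3)
+        >>> conv(torch.rand(2, 64, 14, 14)).shape
+        torch.Size([2, 64, 14, 14])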
+ """ + + def __init__(self, + dim, + order=5, + dw_cfg=dict(type='DW', kernel_size=7), + scale=1.0): + super().__init__() + self.order = order + self.dims = [dim // 2**i for i in range(order)] + self.dims.reverse() + self.proj_in = nn.Conv2d(dim, 2 * dim, 1) + + cfg = copy.deepcopy(dw_cfg) + dw_type = cfg.pop('type') + assert dw_type in ['DW', 'GF'],\ + 'dw_type should be `DW` or `GF`' + if dw_type == 'DW': + self.dwconv = get_dwconv(sum(self.dims), **cfg) + elif dw_type == 'GF': + self.dwconv = GlobalLocalFilter(sum(self.dims), **cfg) + + self.proj_out = nn.Conv2d(dim, dim, 1) + + self.projs = nn.ModuleList([ + nn.Conv2d(self.dims[i], self.dims[i + 1], 1) + for i in range(order - 1) + ]) + + self.scale = scale + + def forward(self, x): + x = self.proj_in(x) + y, x = torch.split(x, (self.dims[0], sum(self.dims)), dim=1) + + x = self.dwconv(x) * self.scale + + dw_list = torch.split(x, self.dims, dim=1) + x = y * dw_list[0] + + for i in range(self.order - 1): + x = self.projs[i](x) * dw_list[i + 1] + + x = self.proj_out(x) + + return x + + +class HorNetBlock(nn.Module): + """A block of HorNet. + + Args: + dim (int): Number of input channels. + order (int): Order of gnConv. + Defaults to 5. + dw_cfg (dict): The Config for dw conv. + Defaults to ``dict(type='DW', kernel_size=7)``. + scale (float): Scaling parameter of gflayer outputs. + Defaults to 1.0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + use_layer_scale (bool): Whether to use use_layer_scale in HorNet + block. Defaults to True. + """ + + def __init__(self, + dim, + order=5, + dw_cfg=dict(type='DW', kernel_size=7), + scale=1.0, + drop_path_rate=0., + use_layer_scale=True): + super().__init__() + self.out_channels = dim + + self.norm1 = HorNetLayerNorm( + dim, eps=1e-6, data_format='channels_first') + self.gnconv = gnConv(dim, order, dw_cfg, scale) + self.norm2 = HorNetLayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + + if use_layer_scale: + self.gamma1 = LayerScale(dim, data_format='channels_first') + self.gamma2 = LayerScale(dim) + else: + self.gamma1, self.gamma2 = nn.Identity(), nn.Identity() + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.gamma1(self.gnconv(self.norm1(x)))) + + input = x + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm2(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + x = self.gamma2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +@MODELS.register_module() +class HorNet(BaseBackbone): + """HorNet backbone. + + A PyTorch implementation of paper `HorNet: Efficient High-Order Spatial + Interactions with Recursive Gated Convolutions + `_ . + Inspiration from https://github.com/raoyongming/HorNet + + Args: + arch (str | dict): HorNet architecture. + + If use string, choose from 'tiny', 'small', 'base' and 'large'. + If use dict, it should have below keys: + + - **base_dim** (int): The base dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **orders** (List[int]): The number of order of gnConv in each + stage. + - **dw_cfg** (List[dict]): The Config for dw conv. + + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + scale (float): Scaling parameter of gflayer outputs. 
Defaults to 1/3. + use_layer_scale (bool): Whether to use use_layer_scale in HorNet + block. Defaults to True. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + gap_before_final_norm (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'base_dim': 64, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['t-gf', 'tiny-gf'], + {'base_dim': 64, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['s', 'small'], + {'base_dim': 96, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['s-gf', 'small-gf'], + {'base_dim': 96, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['b', 'base'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['b-gf', 'base-gf'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['b-gf384', 'base-gf384'], + {'base_dim': 128, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=24, w=12), + dict(type='GF', h=13, w=7)]}), + **dict.fromkeys(['l', 'large'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}), + **dict.fromkeys(['l-gf', 'large-gf'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=14, w=8), + dict(type='GF', h=7, w=4)]}), + **dict.fromkeys(['l-gf384', 'large-gf384'], + {'base_dim': 192, + 'depths': [2, 3, 18, 2], + 'orders': [2, 3, 4, 5], + 'dw_cfg': [ + dict(type='DW', kernel_size=7), + dict(type='DW', kernel_size=7), + dict(type='GF', h=24, w=12), + dict(type='GF', h=13, w=7)]}), + } # yapf: disable + + def __init__(self, + arch='tiny', + in_channels=3, + drop_path_rate=0., + scale=1 / 3, + use_layer_scale=True, + out_indices=(3, ), + frozen_stages=-1, + with_cp=False, + gap_before_final_norm=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if fft is None: + raise RuntimeError( + 'Failed to import torch.fft. 
Please install "torch>=1.7".') + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'base_dim', 'depths', 'orders', 'dw_cfg'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.scale = scale + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.with_cp = with_cp + self.gap_before_final_norm = gap_before_final_norm + + base_dim = self.arch_settings['base_dim'] + dims = list(map(lambda x: 2**x * base_dim, range(4))) + + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4), + HorNetLayerNorm(dims[0], eps=1e-6, data_format='channels_first')) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + HorNetLayerNorm( + dims[i], eps=1e-6, data_format='channels_first'), + nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + total_depth = sum(self.arch_settings['depths']) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + cur_block_idx = 0 + self.stages = nn.ModuleList() + for i in range(4): + stage = nn.Sequential(*[ + HorNetBlock( + dim=dims[i], + order=self.arch_settings['orders'][i], + dw_cfg=self.arch_settings['dw_cfg'][i], + scale=self.scale, + drop_path_rate=dpr[cur_block_idx + j], + use_layer_scale=use_layer_scale) + for j in range(self.arch_settings['depths'][i]) + ]) + self.stages.append(stage) + cur_block_idx += self.arch_settings['depths'][i] + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = len(self.stages) + index + assert 0 <= out_indices[i] <= len(self.stages), \ + f'Invalid out_indices {index}.' 
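+        # ``out_indices`` now holds absolute stage indices; a channels-first
+        # LayerNorm named ``norm{i}`` is registered below for each of them
+        # and applied in ``forward``.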
+ self.out_indices = out_indices + + norm_layer = partial( + HorNetLayerNorm, eps=1e-6, data_format='channels_first') + for i_layer in out_indices: + layer = norm_layer(dims[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + def train(self, mode=True): + super(HorNet, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + # freeze patch embed + m = self.downsample_layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze blocks + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if i in self.out_indices: + # freeze norm + m = getattr(self, f'norm{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + for i in range(4): + x = self.downsample_layers[i](x) + if self.with_cp: + x = checkpoint.checkpoint_sequential(self.stages[i], + len(self.stages[i]), x) + else: + x = self.stages[i](x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean([-2, -1], keepdim=True) + outs.append(norm_layer(gap).flatten(1)) + else: + # The output of LayerNorm2d may be discontiguous, which + # may cause some problem in the downstream tasks + outs.append(norm_layer(x).contiguous()) + return tuple(outs) diff --git a/mmpretrain/models/backbones/hrnet.py b/mmpretrain/models/backbones/hrnet.py new file mode 100644 index 0000000..99afa90 --- /dev/null +++ b/mmpretrain/models/backbones/hrnet.py @@ -0,0 +1,563 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. + + Args: + num_branches (int): The number of branches. + block (``BaseModule``): Convolution block module. + num_blocks (tuple): The number of blocks in each branch. + The length must be equal to ``num_branches``. + num_channels (tuple): The number of base channels in each branch. + The length must be equal to ``num_branches``. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + conv_cfg (dict, optional): Dictionary to construct and config conv + layer. Defaults to None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN')``. + block_init_cfg (dict, optional): The initialization configs of every + blocks. Defaults to None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + num_branches, + block, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + block_init_cfg=None, + init_cfg=None): + super(HRModule, self).__init__(init_cfg) + self.block_init_cfg = block_init_cfg + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, block, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_BLOCKS({len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_CHANNELS({len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) ' \ + f'!= NUM_INCHANNELS({len(in_channels)})' + raise ValueError(error_msg) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + out_channels = num_channels[i] * get_expansion(block) + branches.append( + ResLayer( + block=block, + num_blocks=num_blocks[i], + in_channels=self.in_channels[i], + out_channels=out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + with_cp=self.with_cp, + init_cfg=self.block_init_cfg, + )) + + return ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + # Upsample the feature maps of smaller scales. + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + # Keep the feature map with the same scale. + fuse_layer.append(None) + else: + # Downsample the feature maps of larger scales. + conv_downsamples = [] + for k in range(i - j): + # Use stacked convolution layers to downsample. 
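+                        # Only the last conv of the stack switches to the
+                        # channel number of branch ``i``; the earlier convs
+                        # keep the channels of branch ``j`` and end with ReLU.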
+ if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@MODELS.register_module() +class HRNet(BaseModule): + """HRNet backbone. + + `High-Resolution Representations for Labeling Pixels and Regions + `_. + + Args: + arch (str): The preset HRNet architecture, includes 'w18', 'w30', + 'w32', 'w40', 'w44', 'w48', 'w64'. It will only be used if + extra is ``None``. Defaults to 'w32'. + extra (dict, optional): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of convolution block. Please choose between + 'BOTTLENECK' and 'BASIC'. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of base channels in each branch. + The length must be equal to num_branches. + + Defaults to None. + in_channels (int): Number of input image channels. Defaults to 3. + conv_cfg (dict, optional): Dictionary to construct and config conv + layer. Defaults to None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN')``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to False. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. 
+ + Example: + >>> import torch + >>> from mmpretrain.models import HRNet + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + arch_zoo = { + # num_modules, num_branches, block, num_blocks, num_channels + 'w18': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (18, 36)], + [4, 3, 'BASIC', (4, 4, 4), (18, 36, 72)], + [3, 4, 'BASIC', (4, 4, 4, 4), (18, 36, 72, 144)]], + 'w30': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (30, 60)], + [4, 3, 'BASIC', (4, 4, 4), (30, 60, 120)], + [3, 4, 'BASIC', (4, 4, 4, 4), (30, 60, 120, 240)]], + 'w32': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (32, 64)], + [4, 3, 'BASIC', (4, 4, 4), (32, 64, 128)], + [3, 4, 'BASIC', (4, 4, 4, 4), (32, 64, 128, 256)]], + 'w40': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (40, 80)], + [4, 3, 'BASIC', (4, 4, 4), (40, 80, 160)], + [3, 4, 'BASIC', (4, 4, 4, 4), (40, 80, 160, 320)]], + 'w44': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (44, 88)], + [4, 3, 'BASIC', (4, 4, 4), (44, 88, 176)], + [3, 4, 'BASIC', (4, 4, 4, 4), (44, 88, 176, 352)]], + 'w48': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (48, 96)], + [4, 3, 'BASIC', (4, 4, 4), (48, 96, 192)], + [3, 4, 'BASIC', (4, 4, 4, 4), (48, 96, 192, 384)]], + 'w64': [[1, 1, 'BOTTLENECK', (4, ), (64, )], + [1, 2, 'BASIC', (4, 4), (64, 128)], + [4, 3, 'BASIC', (4, 4, 4), (64, 128, 256)], + [3, 4, 'BASIC', (4, 4, 4, 4), (64, 128, 256, 512)]], + } # yapf:disable + + def __init__(self, + arch='w32', + extra=None, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + norm_eval=False, + with_cp=False, + zero_init_residual=False, + multiscale_output=True, + init_cfg=[ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(HRNet, self).__init__(init_cfg) + + extra = self.parse_arch(arch, extra) + + # Assert configurations of 4 stages are in extra + for i in range(1, 5): + assert f'stage{i}' in extra, f'Missing stage{i} config in "extra".' 
+ # Assert whether the length of `num_blocks` and `num_channels` are + # equal to `num_branches` + cfg = extra[f'stage{i}'] + assert len(cfg['num_blocks']) == cfg['num_branches'] and \ + len(cfg['num_channels']) == cfg['num_branches'] + + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.zero_init_residual = zero_init_residual + + # -------------------- stem net -------------------- + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.add_module(self.norm1_name, norm1) + + self.conv2 = build_conv_layer( + self.conv_cfg, + in_channels=64, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # -------------------- stage 1 -------------------- + self.stage1_cfg = self.extra['stage1'] + base_channels = self.stage1_cfg['num_channels'] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'] + + block = self.blocks_dict[block_type] + num_channels = [ + channel * get_expansion(block) for channel in base_channels + ] + # To align with the original code, use layer1 instead of stage1 here. + self.layer1 = ResLayer( + block, + in_channels=64, + out_channels=num_channels[0], + num_blocks=num_blocks[0]) + pre_num_channels = num_channels + + # -------------------- stage 2~4 -------------------- + for i in range(2, 5): + stage_cfg = self.extra[f'stage{i}'] + base_channels = stage_cfg['num_channels'] + block = self.blocks_dict[stage_cfg['block']] + multiscale_output_ = multiscale_output if i == 4 else True + + num_channels = [ + channel * get_expansion(block) for channel in base_channels + ] + # The transition layer from layer1 to stage2 + transition = self._make_transition_layer(pre_num_channels, + num_channels) + self.add_module(f'transition{i-1}', transition) + stage = self._make_stage( + stage_cfg, num_channels, multiscale_output=multiscale_output_) + self.add_module(f'stage{i}', stage) + + pre_num_channels = num_channels + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + # For existing scale branches, + # add conv block when the channels are not the same. + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(nn.Identity()) + else: + # For new scale branches, add stacked downsample conv blocks. + # For example, num_branches_pre = 2, for the 4th branch, add + # stacked two downsample conv blocks. 
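+                # Every conv halves the resolution and takes the channels of
+                # the lowest-resolution existing branch as input; only the
+                # final conv switches to the new branch's channel number.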
+ conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + block_init_cfg = None + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + block_init_cfg=block_init_cfg)) + + return Sequential(*hr_modules) + + def forward(self, x): + """Forward function.""" + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [x] + + for i in range(2, 5): + # Apply transition + transition = getattr(self, f'transition{i-1}') + inputs = [] + for j, layer in enumerate(transition): + if j < len(x_list): + inputs.append(layer(x_list[j])) + else: + inputs.append(layer(x_list[-1])) + # Forward HRModule + stage = getattr(self, f'stage{i}') + x_list = stage(inputs) + + return tuple(x_list) + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super(HRNet, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def parse_arch(self, arch, extra=None): + if extra is not None: + return extra + + assert arch in self.arch_zoo, \ + ('Invalid arch, please choose arch from ' + f'{list(self.arch_zoo.keys())}, or specify `extra` ' + 'argument directly.') + + extra = dict() + for i, stage_setting in enumerate(self.arch_zoo[arch], start=1): + extra[f'stage{i}'] = dict( + num_modules=stage_setting[0], + num_branches=stage_setting[1], + block=stage_setting[2], + num_blocks=stage_setting[3], + num_channels=stage_setting[4], + ) + + return extra diff --git a/mmpretrain/models/backbones/inception_v3.py b/mmpretrain/models/backbones/inception_v3.py new file mode 100644 index 0000000..1d6c04b --- /dev/null +++ b/mmpretrain/models/backbones/inception_v3.py @@ -0,0 +1,501 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
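+# The file defines ``BasicConv2d``, the Inception blocks A-E, the auxiliary
+# classifier ``InceptionAux`` and the ``InceptionV3`` backbone built from them.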
+from typing import Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_conv_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class BasicConv2d(BaseModule): + """A basic convolution block including convolution, batch norm and ReLU. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + conv_cfg (dict, optional): The config of convolution layer. + Defaults to None, which means to use ``nn.Conv2d``. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + **kwargs: Other keyword arguments of the convolution layer. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs) -> None: + super().__init__(init_cfg=init_cfg) + self.conv = build_conv_layer( + conv_cfg, in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class InceptionA(BaseModule): + """Type-A Inception block. + + Args: + in_channels (int): The number of input channels. + pool_features (int): The number of channels in pooling branch. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + pool_features: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + + self.branch5x5_1 = BasicConv2d( + in_channels, 48, kernel_size=1, conv_cfg=conv_cfg) + self.branch5x5_2 = BasicConv2d( + 48, 64, kernel_size=5, padding=2, conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3 = BasicConv2d( + 96, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, pool_features, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionB(BaseModule): + """Type-B Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. 
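+
+    Example:
+        An illustrative shape check (the block halves the resolution):
+
+        >>> import torch
+        >>> block = InceptionB(288)
+        >>> block(torch.rand(1, 288, 35, 35)).shape
+        torch.Size([1, 768, 17, 17])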
+ """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch3x3 = BasicConv2d( + in_channels, 384, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 64, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3 = BasicConv2d( + 96, 96, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = self.branch_pool(x) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionC(BaseModule): + """Type-C Inception block. + + Args: + in_channels (int): The number of input channels. + channels_7x7 (int): The number of channels in 7x7 convolution branch. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + channels_7x7: int, + conv_cfg: Optional[dict] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + c7 = channels_7x7 + self.branch7x7_1 = BasicConv2d( + in_channels, c7, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7_2 = BasicConv2d( + c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7_3 = BasicConv2d( + c7, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + + self.branch7x7dbl_1 = BasicConv2d( + in_channels, c7, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7dbl_2 = BasicConv2d( + c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7dbl_3 = BasicConv2d( + c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7dbl_4 = BasicConv2d( + c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7dbl_5 = BasicConv2d( + c7, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionD(BaseModule): + """Type-D Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. 
+ init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.branch3x3_1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3_2 = BasicConv2d( + 192, 320, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch7x7x3_1 = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + self.branch7x7x3_2 = BasicConv2d( + 192, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg) + self.branch7x7x3_3 = BasicConv2d( + 192, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg) + self.branch7x7x3_4 = BasicConv2d( + 192, 192, kernel_size=3, stride=2, conv_cfg=conv_cfg) + + self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = self.branch_pool(x) + outputs = [branch3x3, branch7x7x3, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionE(BaseModule): + """Type-E Inception block. + + Args: + in_channels (int): The number of input channels. + conv_cfg (dict, optional): The convolution layer config in the + :class:`BasicConv2d` block. Defaults to None. + init_cfg (dict, optional): The config of initialization. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + conv_cfg: Optional[dict] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.branch1x1 = BasicConv2d( + in_channels, 320, kernel_size=1, conv_cfg=conv_cfg) + + self.branch3x3_1 = BasicConv2d( + in_channels, 384, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3_2a = BasicConv2d( + 384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg) + self.branch3x3_2b = BasicConv2d( + 384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg) + + self.branch3x3dbl_1 = BasicConv2d( + in_channels, 448, kernel_size=1, conv_cfg=conv_cfg) + self.branch3x3dbl_2 = BasicConv2d( + 448, 384, kernel_size=3, padding=1, conv_cfg=conv_cfg) + self.branch3x3dbl_3a = BasicConv2d( + 384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg) + self.branch3x3dbl_3b = BasicConv2d( + 384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg) + + self.branch_pool_downsample = nn.AvgPool2d( + kernel_size=3, stride=1, padding=1) + self.branch_pool = BasicConv2d( + in_channels, 192, kernel_size=1, conv_cfg=conv_cfg) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = self.branch_pool_downsample(x) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class InceptionAux(BaseModule): + """The Inception block for the auxiliary classification branch. 
+
+    Args:
+        in_channels (int): The number of input channels.
+        num_classes (int): The number of categories.
+        conv_cfg (dict, optional): The convolution layer config in the
+            :class:`BasicConv2d` block. Defaults to None.
+        init_cfg (dict, optional): The config of initialization.
+            Defaults to use trunc normal with ``std=0.01`` for Conv2d layers
+            and trunc normal with ``std=0.001`` for Linear layers.
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 num_classes: int,
+                 conv_cfg: Optional[dict] = None,
+                 init_cfg: Optional[dict] = [
+                     dict(type='TruncNormal', layer='Conv2d', std=0.01),
+                     dict(type='TruncNormal', layer='Linear', std=0.001)
+                 ]):
+        super().__init__(init_cfg=init_cfg)
+        self.downsample = nn.AvgPool2d(kernel_size=5, stride=3)
+        self.conv0 = BasicConv2d(
+            in_channels, 128, kernel_size=1, conv_cfg=conv_cfg)
+        self.conv1 = BasicConv2d(128, 768, kernel_size=5, conv_cfg=conv_cfg)
+        self.gap = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(768, num_classes)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward function."""
+        # N x 768 x 17 x 17
+        x = self.downsample(x)
+        # N x 768 x 5 x 5
+        x = self.conv0(x)
+        # N x 128 x 5 x 5
+        x = self.conv1(x)
+        # N x 768 x 1 x 1
+        # Adaptive average pooling
+        x = self.gap(x)
+        # N x 768 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 768
+        x = self.fc(x)
+        # N x 1000
+        return x
+
+
+@MODELS.register_module()
+class InceptionV3(BaseBackbone):
+    """Inception V3 backbone.
+
+    A PyTorch implementation of `Rethinking the Inception Architecture for
+    Computer Vision <https://arxiv.org/abs/1512.00567>`_
+
+    This implementation is modified from
+    https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py.
+    Licensed under the BSD 3-Clause License.
+
+    Args:
+        num_classes (int): The number of categories. Defaults to 1000.
+        aux_logits (bool): Whether to enable the auxiliary branch. If False,
+            the auxiliary logits output will be None. Defaults to False.
+        dropout (float): Dropout rate. Defaults to 0.5.
+        init_cfg (dict, optional): The config of initialization. Defaults
+            to use trunc normal with ``std=0.1`` for all Conv2d and Linear
+            layers and constant with ``val=1`` for all BatchNorm2d layers.
+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import build_backbone
+        >>>
+        >>> inputs = torch.rand(2, 3, 299, 299)
+        >>> cfg = dict(type='InceptionV3', num_classes=100)
+        >>> backbone = build_backbone(cfg)
+        >>> aux_out, out = backbone(inputs)
+        >>> # The auxiliary branch is disabled by default.
+ >>> assert aux_out is None + >>> print(out.shape) + torch.Size([2, 100]) + >>> cfg = dict(type='InceptionV3', num_classes=100, aux_logits=True) + >>> backbone = build_backbone(cfg) + >>> aux_out, out = backbone(inputs) + >>> print(aux_out.shape, out.shape) + torch.Size([2, 100]) torch.Size([2, 100]) + """ + + def __init__( + self, + num_classes: int = 1000, + aux_logits: bool = False, + dropout: float = 0.5, + init_cfg: Optional[dict] = [ + dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=0.1), + dict(type='Constant', layer='BatchNorm2d', val=1) + ], + ) -> None: + super().__init__(init_cfg=init_cfg) + + self.aux_logits = aux_logits + self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + self.AuxLogits: Optional[nn.Module] = None + if aux_logits: + self.AuxLogits = InceptionAux(768, num_classes) + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE(1280) + self.Mixed_7c = InceptionE(2048) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(p=dropout) + self.fc = nn.Linear(2048, num_classes) + + def forward( + self, + x: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + """Forward function.""" + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.maxpool1(x) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.maxpool2(x) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + aux: Optional[torch.Tensor] = None + if self.aux_logits and self.training: + aux = self.AuxLogits(x) + # N x 768 x 17 x 17 + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + # Adaptive average pooling + x = self.avgpool(x) + # N x 2048 x 1 x 1 + x = self.dropout(x) + # N x 2048 x 1 x 1 + x = torch.flatten(x, 1) + # N x 2048 + x = self.fc(x) + # N x 1000 (num_classes) + return aux, x diff --git a/mmpretrain/models/backbones/lenet.py b/mmpretrain/models/backbones/lenet.py new file mode 100644 index 0000000..8e423c0 --- /dev/null +++ b/mmpretrain/models/backbones/lenet.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch.nn as nn + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class LeNet5(BaseBackbone): + """`LeNet5 `_ backbone. + + The input for LeNet-5 is a 32×32 grayscale image. + + Args: + num_classes (int): number of classes for classification. + The default value is -1, which uses the backbone as + a feature extractor without the top classifier. + """ + + def __init__(self, num_classes=-1): + super(LeNet5, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(1, 6, kernel_size=5, stride=1), nn.Tanh(), + nn.AvgPool2d(kernel_size=2), + nn.Conv2d(6, 16, kernel_size=5, stride=1), nn.Tanh(), + nn.AvgPool2d(kernel_size=2), + nn.Conv2d(16, 120, kernel_size=5, stride=1), nn.Tanh()) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(120, 84), + nn.Tanh(), + nn.Linear(84, num_classes), + ) + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = self.classifier(x.squeeze()) + + return (x, ) diff --git a/mmpretrain/models/backbones/levit.py b/mmpretrain/models/backbones/levit.py new file mode 100644 index 0000000..5f7aa32 --- /dev/null +++ b/mmpretrain/models/backbones/levit.py @@ -0,0 +1,522 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, fuse_conv_bn +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +class HybridBackbone(BaseModule): + + def __init__( + self, + embed_dim, + kernel_size=3, + stride=2, + pad=1, + dilation=1, + groups=1, + act_cfg=dict(type='HSwish'), + conv_cfg=None, + norm_cfg=dict(type='BN'), + init_cfg=None, + ): + super(HybridBackbone, self).__init__(init_cfg=init_cfg) + + self.input_channels = [ + 3, embed_dim // 8, embed_dim // 4, embed_dim // 2 + ] + self.output_channels = [ + embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim + ] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.patch_embed = Sequential() + + for i in range(len(self.input_channels)): + conv_bn = ConvolutionBatchNorm( + self.input_channels[i], + self.output_channels[i], + kernel_size=kernel_size, + stride=stride, + pad=pad, + dilation=dilation, + groups=groups, + norm_cfg=norm_cfg, + ) + self.patch_embed.add_module('%d' % (2 * i), conv_bn) + if i < len(self.input_channels) - 1: + self.patch_embed.add_module('%d' % (i * 2 + 1), + build_activation_layer(act_cfg)) + + def forward(self, x): + x = self.patch_embed(x) + return x + + +class ConvolutionBatchNorm(BaseModule): + + def __init__( + self, + in_channel, + out_channel, + kernel_size=3, + stride=2, + pad=1, + dilation=1, + groups=1, + norm_cfg=dict(type='BN'), + ): + super(ConvolutionBatchNorm, self).__init__() + self.conv = nn.Conv2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=pad, + dilation=dilation, + groups=groups, + bias=False) + self.bn = build_norm_layer(norm_cfg, out_channel) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + @torch.no_grad() + def fuse(self): + return fuse_conv_bn(self).conv + + +class LinearBatchNorm(BaseModule): + + def __init__(self, in_feature, out_feature, norm_cfg=dict(type='BN1d')): + super(LinearBatchNorm, self).__init__() + self.linear = nn.Linear(in_feature, out_feature, 
bias=False) + self.bn = build_norm_layer(norm_cfg, out_feature) + + def forward(self, x): + x = self.linear(x) + x = self.bn(x.flatten(0, 1)).reshape_as(x) + return x + + @torch.no_grad() + def fuse(self): + w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5 + w = self.linear.weight * w[:, None] + b = self.bn.bias - self.bn.running_mean * self.bn.weight / \ + (self.bn.running_var + self.bn.eps) ** 0.5 + + factory_kwargs = { + 'device': self.linear.weight.device, + 'dtype': self.linear.weight.dtype + } + bias = nn.Parameter( + torch.empty(self.linear.out_features, **factory_kwargs)) + self.linear.register_parameter('bias', bias) + self.linear.weight.data.copy_(w) + self.linear.bias.data.copy_(b) + return self.linear + + +class Residual(BaseModule): + + def __init__(self, block, drop_path_rate=0.): + super(Residual, self).__init__() + self.block = block + if drop_path_rate > 0: + self.drop_path = DropPath(drop_path_rate) + else: + self.drop_path = nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.block(x)) + return x + + +class Attention(BaseModule): + + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + act_cfg=dict(type='HSwish'), + resolution=14, + ): + super(Attention, self).__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = LinearBatchNorm(dim, h) + self.proj = nn.Sequential( + build_activation_layer(act_cfg), LinearBatchNorm(self.dh, dim)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + """change the mode of model.""" + super(Attention, self).train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape # 2 196 128 + qkv = self.qkv(x) # 2 196 128 + q, k, v = qkv.view(B, N, self.num_heads, -1).split( + [self.key_dim, self.key_dim, self.d], + dim=3) # q 2 196 4 16 ; k 2 196 4 16; v 2 196 4 32 + q = q.permute(0, 2, 1, 3) # 2 4 196 16 + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ((q @ k.transpose(-2, -1)) * + self.scale # 2 4 196 16 * 2 4 16 196 -> 2 4 196 196 + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) # 2 4 196 196 -> 2 4 196 196 + x = (attn @ v).transpose(1, 2).reshape( + B, N, + self.dh) # 2 4 196 196 * 2 4 196 32 -> 2 4 196 32 -> 2 196 128 + x = self.proj(x) + return x + + +class MLP(nn.Sequential): + + def __init__(self, embed_dim, mlp_ratio, act_cfg=dict(type='HSwish')): + super(MLP, self).__init__() + h = embed_dim * mlp_ratio + self.linear1 = LinearBatchNorm(embed_dim, h) + self.activation = build_activation_layer(act_cfg) + self.linear2 = LinearBatchNorm(h, embed_dim) + + def forward(self, x): + x = self.linear1(x) + x = self.activation(x) 
+ x = self.linear2(x) + return x + + +class Subsample(BaseModule): + + def __init__(self, stride, resolution): + super(Subsample, self).__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, _, C = x.shape + # B, N, C -> B, H, W, C + x = x.view(B, self.resolution, self.resolution, C) + x = x[:, ::self.stride, ::self.stride] + x = x.reshape(B, -1, C) # B, H', W', C -> B, N', C + return x + + +class AttentionSubsample(nn.Sequential): + + def __init__(self, + in_dim, + out_dim, + key_dim, + num_heads=8, + attn_ratio=2, + act_cfg=dict(type='HSwish'), + stride=2, + resolution=14): + super(AttentionSubsample, self).__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.sub_resolution = (resolution - 1) // stride + 1 + h = self.dh + nh_kd + self.kv = LinearBatchNorm(in_dim, h) + + self.q = nn.Sequential( + Subsample(stride, resolution), LinearBatchNorm(in_dim, nh_kd)) + self.proj = nn.Sequential( + build_activation_layer(act_cfg), LinearBatchNorm(self.dh, out_dim)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + sub_points = list( + itertools.product( + range(self.sub_resolution), range(self.sub_resolution))) + N = len(points) + N_sub = len(sub_points) + attention_offsets = {} + idxs = [] + for p1 in sub_points: + for p2 in points: + size = 1 + offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N_sub, N)) + + @torch.no_grad() + def train(self, mode=True): + super(AttentionSubsample, self).train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, + -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.sub_resolution**2, self.num_heads, + self.key_dim).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + \ + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +@MODELS.register_module() +class LeViT(BaseBackbone): + """LeViT backbone. + + A PyTorch implementation of `LeViT: A Vision Transformer in ConvNet's + Clothing for Faster Inference `_ + + Modified from the official implementation: + https://github.com/facebookresearch/LeViT + + Args: + arch (str | dict): LeViT architecture. + + If use string, choose from '128s', '128', '192', '256' and '384'. + If use dict, it should have below keys: + + - **embed_dims** (List[int]): The embed dimensions of each stage. + - **key_dims** (List[int]): The embed dimensions of the key in the + attention layers of each stage. + - **num_heads** (List[int]): The number of heads in each stage. + - **depths** (List[int]): The number of blocks in each stage. 
+ + img_size (int): Input image size + patch_size (int | tuple): The patch size. Deault to 16 + attn_ratio (int): Ratio of hidden dimensions of the value in attention + layers. Defaults to 2. + mlp_ratio (int): Ratio of hidden dimensions in MLP layers. + Defaults to 2. + act_cfg (dict): The config of activation functions. + Defaults to ``dict(type='HSwish')``. + hybrid_backbone (callable): A callable object to build the patch embed + module. Defaults to use :class:`HybridBackbone`. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + deploy (bool): Whether to switch the model structure to + deployment mode. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + '128s': { + 'embed_dims': [128, 256, 384], + 'num_heads': [4, 6, 8], + 'depths': [2, 3, 4], + 'key_dims': [16, 16, 16], + }, + '128': { + 'embed_dims': [128, 256, 384], + 'num_heads': [4, 8, 12], + 'depths': [4, 4, 4], + 'key_dims': [16, 16, 16], + }, + '192': { + 'embed_dims': [192, 288, 384], + 'num_heads': [3, 5, 6], + 'depths': [4, 4, 4], + 'key_dims': [32, 32, 32], + }, + '256': { + 'embed_dims': [256, 384, 512], + 'num_heads': [4, 6, 8], + 'depths': [4, 4, 4], + 'key_dims': [32, 32, 32], + }, + '384': { + 'embed_dims': [384, 512, 768], + 'num_heads': [6, 9, 12], + 'depths': [4, 4, 4], + 'key_dims': [32, 32, 32], + }, + } + + def __init__(self, + arch, + img_size=224, + patch_size=16, + attn_ratio=2, + mlp_ratio=2, + act_cfg=dict(type='HSwish'), + hybrid_backbone=HybridBackbone, + out_indices=-1, + deploy=False, + drop_path_rate=0, + init_cfg=None): + super(LeViT, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch = self.arch_zoo[arch] + elif isinstance(arch, dict): + essential_keys = {'embed_dim', 'num_heads', 'depth', 'key_dim'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch = arch + else: + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + self.embed_dims = self.arch['embed_dims'] + self.num_heads = self.arch['num_heads'] + self.key_dims = self.arch['key_dims'] + self.depths = self.arch['depths'] + self.num_stages = len(self.embed_dims) + self.drop_path_rate = drop_path_rate + + self.patch_embed = hybrid_backbone(self.embed_dims[0]) + + self.resolutions = [] + resolution = img_size // patch_size + self.stages = ModuleList() + for i, (embed_dims, key_dims, depth, num_heads) in enumerate( + zip(self.embed_dims, self.key_dims, self.depths, + self.num_heads)): + blocks = [] + if i > 0: + downsample = AttentionSubsample( + in_dim=self.embed_dims[i - 1], + out_dim=embed_dims, + key_dim=key_dims, + num_heads=self.embed_dims[i - 1] // key_dims, + attn_ratio=4, + act_cfg=act_cfg, + stride=2, + resolution=resolution) + blocks.append(downsample) + resolution = downsample.sub_resolution + if mlp_ratio > 0: # mlp_ratio + blocks.append( + Residual( + MLP(embed_dims, mlp_ratio, act_cfg=act_cfg), + self.drop_path_rate)) + self.resolutions.append(resolution) + for _ in range(depth): + blocks.append( + Residual( + Attention( + embed_dims, + key_dims, + num_heads, + attn_ratio=attn_ratio, + act_cfg=act_cfg, + resolution=resolution, + ), self.drop_path_rate)) + if mlp_ratio > 0: + blocks.append( + Residual( + MLP(embed_dims, mlp_ratio, 
act_cfg=act_cfg), + self.drop_path_rate)) + + self.stages.append(Sequential(*blocks)) + + if isinstance(out_indices, int): + out_indices = [out_indices] + elif isinstance(out_indices, tuple): + out_indices = list(out_indices) + elif not isinstance(out_indices, list): + raise TypeError('"out_indices" must by a list, tuple or int, ' + f'get {type(out_indices)} instead.') + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert 0 <= out_indices[i] < self.num_stages, \ + f'Invalid out_indices {index}.' + self.out_indices = out_indices + + self.deploy = False + if deploy: + self.switch_to_deploy() + + def switch_to_deploy(self): + if self.deploy: + return + fuse_parameters(self) + self.deploy = True + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) # B, C, H, W -> B, L, C + outs = [] + for i, stage in enumerate(self.stages): + x = stage(x) + B, _, C = x.shape + if i in self.out_indices: + out = x.reshape(B, self.resolutions[i], self.resolutions[i], C) + out = out.permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + +def fuse_parameters(module): + for child_name, child in module.named_children(): + if hasattr(child, 'fuse'): + setattr(module, child_name, child.fuse()) + else: + fuse_parameters(child) diff --git a/mmpretrain/models/backbones/mixmim.py b/mmpretrain/models/backbones/mixmim.py new file mode 100644 index 0000000..2c67aa0 --- /dev/null +++ b/mmpretrain/models/backbones/mixmim.py @@ -0,0 +1,533 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging +from mmengine.model import BaseModule +from torch import nn +from torch.utils.checkpoint import checkpoint + +from mmpretrain.registry import MODELS +from ..utils import WindowMSA, to_2tuple +from .base_backbone import BaseBackbone +from .vision_transformer import TransformerEncoderLayer + + +class MixMIMWindowAttention(WindowMSA): + """MixMIM Window Attention. + + Compared with WindowMSA, we add some modifications + in ``forward`` to meet the requirement of MixMIM during + pretraining. + + Implements one windown attention in MixMIM. + Args: + embed_dims (int): The feature dimension. + window_size (list): The height and width of the window. + num_heads (int): The number of head in attention. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
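+
+    Example:
+        A minimal sketch of the unmasked forward pass; the window size, head
+        count and the batch of flattened 7x7 windows below are assumed values
+        for illustration only.
+
+        >>> import torch
+        >>> attn = MixMIMWindowAttention(
+        ...     embed_dims=96, window_size=(7, 7), num_heads=3)
+        >>> windows = torch.rand(8, 49, 96)
+        >>> attn(windows).shape
+        torch.Size([8, 49, 96])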
+ """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop_rate, + proj_drop=proj_drop_rate, + init_cfg=init_cfg) + + def forward(self, x, mask=None): + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + mask = mask.reshape(B_, 1, 1, N) + mask_new = mask * mask.transpose( + 2, 3) + (1 - mask) * (1 - mask).transpose(2, 3) + mask_new = 1 - mask_new + + if mask_new.dtype == torch.float16: + attn = attn - 65500 * mask_new + else: + attn = attn - 1e30 * mask_new + + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MixMIMBlock(TransformerEncoderLayer): + """MixMIM Block. Implements one block in MixMIM. + + Args: + embed_dims (int): The feature dimension. + input_resolution (tuple): Input resolution of this layer. + num_heads (int): The number of head in attention, + window_size (list): The height and width of the window. + mlp_ratio (int): The MLP ration in FFN. + num_fcs (int): The number of linear layers in a block. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. + Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
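+
+    Example:
+        A minimal sketch with assumed sizes (a 14x14 token grid split into
+        7x7 windows); it only checks that the block keeps the token layout.
+
+        >>> import torch
+        >>> block = MixMIMBlock(
+        ...     embed_dims=96, input_resolution=(14, 14), num_heads=3)
+        >>> tokens = torch.rand(2, 14 * 14, 96)
+        >>> block(tokens).shape
+        torch.Size([2, 196, 96])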
+ """ + + def __init__(self, + embed_dims, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4., + num_fcs=2, + qkv_bias=True, + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=int(mlp_ratio * embed_dims), + drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + if min(self.input_resolution) <= self.window_size: + self.window_size = min(self.input_resolution) + + self.attn = MixMIMWindowAttention( + embed_dims=embed_dims, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate) + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + @staticmethod + def window_reverse(windows, H, W, window_size): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + @staticmethod + def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + def forward(self, x, attn_mask=None): + H, W = self.input_resolution + B, L, C = x.shape + + shortcut = x + x = self.ln1(x) + x = x.view(B, H, W, C) + + # partition windows + x_windows = self.window_partition( + x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + if attn_mask is not None: + attn_mask = attn_mask.repeat(B, 1, 1) # B, N, 1 + attn_mask = attn_mask.view(B, H, W, 1) + attn_mask = self.window_partition(attn_mask, self.window_size) + attn_mask = attn_mask.view(-1, self.window_size * self.window_size, + 1) + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + x = self.window_reverse(attn_windows, H, W, + self.window_size) # B H' W' C + + x = x.view(B, H * W, C) + + x = shortcut + self.drop_path(x) + + x = self.ffn(self.norm2(x), identity=x) # ffn contains DropPath + + return x + + +class MixMIMLayer(BaseModule): + """Implements one MixMIM layer, which may contains several MixMIM blocks. + + Args: + embed_dims (int): The feature dimension. + input_resolution (tuple): Input resolution of this layer. + depth (int): The number of blocks in this layer. + num_heads (int): The number of head in attention, + window_size (list): The height and width of the window. + mlp_ratio (int): The MLP ration in FFN. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): attention drop rate. + Defaults to 0. 
+ drop_path_rate (float): stochastic depth rate. + Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + downsample (class, optional): Downsample the output of blocks b + y patch merging.Defaults to None. + use_checkpoint (bool): Whether use the checkpoint to + reduce GPU memory cost. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + input_resolution: int, + depth: int, + num_heads: int, + window_size: int, + mlp_ratio=4., + qkv_bias=True, + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=[0.], + norm_cfg=dict(type='LN'), + downsample=None, + use_checkpoint=False, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList() + for i in range(depth): + self.blocks.append( + MixMIMBlock( + embed_dims=embed_dims, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate[i], + norm_cfg=norm_cfg)) + # patch merging layer + if downsample is not None: + self.downsample = downsample( + in_channels=embed_dims, + out_channels=2 * embed_dims, + norm_cfg=norm_cfg) + else: + self.downsample = None + + def forward(self, x, attn_mask=None): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask=attn_mask) + if self.downsample is not None: + x, _ = self.downsample(x, self.input_resolution) + return x + + def extra_repr(self) -> str: + return f'dim={self.embed_dims}, \ + input_resolution={self.input_resolution}, depth={self.depth}' + + +@MODELS.register_module() +class MixMIMTransformer(BaseBackbone): + """MixMIM backbone. + + A PyTorch implement of : ` MixMIM: Mixed and Masked Image + Modeling for Efficient Visual Representation Learning + `_ + + Args: + arch (str | dict): MixMIM architecture. If use string, + choose from 'base','large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + + Defaults to 'base'. + mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to mlp_ratio + the most common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + window_size (list): The height and width of the window. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + attn_drop_rate (float): attention drop rate. Defaults to 0. + use_checkpoint (bool): Whether use the checkpoint to + reduce GPU memory cost. 
+ init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 352, + 'depths': [2, 2, 18, 2], + 'num_heads': [11, 22, 44, 88] + }), + } + + def __init__( + self, + arch='base', + mlp_ratio=4, + img_size=224, + patch_size=4, + in_channels=3, + window_size=[14, 14, 14, 7], + qkv_bias=True, + patch_cfg=dict(), + norm_cfg=dict(type='LN'), + drop_rate=0.0, + drop_path_rate=0.0, + attn_drop_rate=0.0, + use_checkpoint=False, + init_cfg: Optional[dict] = None, + ) -> None: + super(MixMIMTransformer, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + + self.encoder_stride = 32 + + self.num_layers = len(self.depths) + self.qkv_bias = qkv_bias + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.use_checkpoint = use_checkpoint + self.mlp_ratio = mlp_ratio + self.window_size = window_size + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + self.dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + self.layers.append( + MixMIMLayer( + embed_dims=int(self.embed_dims * 2**i_layer), + input_resolution=(self.patch_resolution[0] // (2**i_layer), + self.patch_resolution[1] // + (2**i_layer)), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + proj_drop_rate=self.drop_rate, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=self.dpr[sum(self.depths[:i_layer] + ):sum(self.depths[:i_layer + + 1])], + norm_cfg=norm_cfg, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=self.use_checkpoint)) + + self.num_features = int(self.embed_dims * 2**(self.num_layers - 1)) + self.drop_after_pos = nn.Dropout(p=self.drop_rate) + + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, self.embed_dims), + requires_grad=False) + + _, self.norm = build_norm_layer(norm_cfg, self.num_features) + + def forward(self, x: torch.Tensor): + x, _ = self.patch_embed(x) + + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for layer in self.layers: + x = layer(x, attn_mask=None) + + x = self.norm(x) + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = 
torch.flatten(x, 1) + + return (x, ) + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = sum(self.depths) + 2 + + if not param_name.startswith(prefix): + # For subsequent module like neck and head + if param_name.startswith('neck'): + return num_layers - 2, num_layers + else: + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + stem_layers = ('patch_embed', 'absolute_pos_embed', 'pos_embed') + if any(stem in param_name for stem in stem_layers): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + block_id = param_name.split('.')[3] + + if block_id in ('downsample', 'reduction', 'norm'): + layer_depth = sum(self.depths[:layer_id + 1]) + else: + layer_depth = sum(self.depths[:layer_id]) + int(block_id) + 1 + else: + layer_depth = num_layers - 2 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/mlp_mixer.py b/mmpretrain/models/backbones/mlp_mixer.py new file mode 100644 index 0000000..26fb8ce --- /dev/null +++ b/mmpretrain/models/backbones/mlp_mixer.py @@ -0,0 +1,263 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from ..utils import to_2tuple +from .base_backbone import BaseBackbone + + +class MixerBlock(BaseModule): + """Mlp-Mixer basic block. + + Basic module of `MLP-Mixer: An all-MLP Architecture for Vision + `_ + + Args: + num_tokens (int): The number of patched tokens + embed_dims (int): The feature dimension + tokens_mlp_dims (int): The hidden dimension for tokens FFNs + channels_mlp_dims (int): The hidden dimension for channels FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
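+
+    Example:
+        A minimal sketch using the 'small' arch sizes as assumed values; both
+        the token-mixing and the channel-mixing FFN preserve the input shape.
+
+        >>> import torch
+        >>> block = MixerBlock(num_tokens=196, embed_dims=512,
+        ...                    tokens_mlp_dims=256, channels_mlp_dims=2048)
+        >>> x = torch.rand(2, 196, 512)
+        >>> block(x).shape
+        torch.Size([2, 196, 512])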
+ """ + + def __init__(self, + num_tokens, + embed_dims, + tokens_mlp_dims, + channels_mlp_dims, + drop_rate=0., + drop_path_rate=0., + num_fcs=2, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(MixerBlock, self).__init__(init_cfg=init_cfg) + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + self.token_mix = FFN( + embed_dims=num_tokens, + feedforward_channels=tokens_mlp_dims, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + self.channel_mix = FFN( + embed_dims=embed_dims, + feedforward_channels=channels_mlp_dims, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def init_weights(self): + super(MixerBlock, self).init_weights() + for m in self.token_mix.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + for m in self.channel_mix.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + out = self.norm1(x).transpose(1, 2) + x = x + self.token_mix(out).transpose(1, 2) + x = self.channel_mix(self.norm2(x), identity=x) + return x + + +@MODELS.register_module() +class MlpMixer(BaseBackbone): + """Mlp-Mixer backbone. + + Pytorch implementation of `MLP-Mixer: An all-MLP Architecture for Vision + `_ + + Args: + arch (str | dict): MLP Mixer architecture. If use string, choose from + 'small', 'base' and 'large'. If use dict, it should have below + keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of MLP blocks. + - **tokens_mlp_dims** (int): The hidden dimensions for tokens FFNs. + - **channels_mlp_dims** (int): The The hidden dimensions for + channels FFNs. + + Defaults to 'base'. + img_size (int | tuple): The input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + out_indices (Sequence | int): Output from which layer. + Defaults to -1, means the last layer. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + act_cfg (dict): The activation config for FFNs. Default GELU. + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each mixer block layer. + Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
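+
+    Example:
+        A minimal sketch of the default forward pass; the 'small' arch and
+        the 224x224 input are assumed values. The returned feature keeps a
+        (batch, channels, tokens) layout.
+
+        >>> import torch
+        >>> model = MlpMixer(arch='small')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> feats = model(inputs)
+        >>> feats[-1].shape
+        torch.Size([1, 512, 196])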
+ """ + + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 512, + 'num_layers': 8, + 'tokens_mlp_dims': 256, + 'channels_mlp_dims': 2048, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'tokens_mlp_dims': 384, + 'channels_mlp_dims': 3072, + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'tokens_mlp_dims': 512, + 'channels_mlp_dims': 4096, + }), + } + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super(MlpMixer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'tokens_mlp_dims', + 'channels_mlp_dims' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.tokens_mlp_dims = self.arch_settings['tokens_mlp_dims'] + self.channels_mlp_dims = self.arch_settings['channels_mlp_dims'] + + self.img_size = to_2tuple(img_size) + + _patch_cfg = dict( + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must be a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + else: + assert index >= self.num_layers, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + num_tokens=num_patches, + embed_dims=self.embed_dims, + tokens_mlp_dims=self.tokens_mlp_dims, + channels_mlp_dims=self.channels_mlp_dims, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(MixerBlock(**_layer_cfg)) + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def forward(self, x): + assert x.shape[2:] == self.img_size, \ + "The MLP-Mixer doesn't support dynamic input shape. 
" \ + f'Please input images with shape {self.img_size}' + x, _ = self.patch_embed(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1: + x = self.norm1(x) + + if i in self.out_indices: + out = x.transpose(1, 2) + outs.append(out) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/mobilenet_v2.py b/mmpretrain/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000..bca1418 --- /dev/null +++ b/mmpretrain/models/backbones/mobilenet_v2.py @@ -0,0 +1,264 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import make_divisible +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class InvertedResidual(BaseModule): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=1, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class MobileNetV2(BaseBackbone): + """MobileNetV2 backbone. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). 
+ Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks, stride. + arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], + [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], + [6, 320, 1, 1]] + + def __init__(self, + widen_factor=1., + out_indices=(7, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(MobileNetV2, self).__init__(init_cfg) + self.widen_factor = widen_factor + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 8): + raise ValueError('the item in out_indices must in ' + f'range(0, 8). But received {index}') + + if frozen_stages not in range(-1, 8): + raise ValueError('frozen_stages must be in range(-1, 8). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks, stride = layer_cfg + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + if widen_factor > 1.0: + self.out_channel = int(1280 * widen_factor) + else: + self.out_channel = 1280 + + layer = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channel, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.add_module('conv2', layer) + self.layers.append('conv2') + + def make_layer(self, out_channels, num_blocks, stride, expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. Default: 6. 
+ """ + layers = [] + for i in range(num_blocks): + if i >= 1: + stride = 1 + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride, + expand_ratio=expand_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/mobilenet_v3.py b/mmpretrain/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000..577dba9 --- /dev/null +++ b/mmpretrain/models/backbones/mobilenet_v3.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import InvertedResidual +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class MobileNetV3(BaseBackbone): + """MobileNetV3 backbone. + + Args: + arch (str): Architecture of mobilnetv3, from {small, large}. + Default: small. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (None or Sequence[int]): Output from which stages. + Default: None, which means output tensors from final stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. 
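+
+    Example:
+        A minimal sketch of the default configuration; the 224x224 input is
+        an assumed value. With ``arch='small'`` the default ``out_indices``
+        points at the final 1x1 conv layer.
+
+        >>> import torch
+        >>> model = MobileNetV3(arch='small')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outs = model(inputs)
+        >>> outs[-1].shape
+        torch.Size([1, 576, 7, 7])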
+ """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 2], + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'small_075': [[3, 16, 16, True, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 2], + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 32, True, 'HSwish', 2], + [5, 192, 32, True, 'HSwish', 1], + [5, 192, 32, True, 'HSwish', 1], + [5, 96, 40, True, 'HSwish', 1], + [5, 120, 40, True, 'HSwish', 1], + [5, 240, 72, True, 'HSwish', 2], + [5, 432, 72, True, 'HSwish', 1], + [5, 432, 72, True, 'HSwish', 1]], + 'small_050': [[3, 16, 8, True, 'ReLU', 2], + [3, 40, 16, False, 'ReLU', 2], + [3, 56, 16, False, 'ReLU', 1], + [5, 64, 24, True, 'HSwish', 2], + [5, 144, 24, True, 'HSwish', 1], + [5, 144, 24, True, 'HSwish', 1], + [5, 72, 24, True, 'HSwish', 1], + [5, 72, 24, True, 'HSwish', 1], + [5, 144, 48, True, 'HSwish', 2], + [5, 288, 48, True, 'HSwish', 1], + [5, 288, 48, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], + [3, 64, 24, False, 'ReLU', 2], + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN', eps=0.001, momentum=0.01), + out_indices=None, + frozen_stages=-1, + norm_eval=False, + with_cp=False, + init_cfg=[ + dict( + type='Kaiming', + layer=['Conv2d'], + nonlinearity='leaky_relu'), + dict(type='Normal', layer=['Linear'], std=0.01), + dict(type='Constant', layer=['BatchNorm2d'], val=1) + ]): + super(MobileNetV3, self).__init__(init_cfg) + assert arch in self.arch_settings + if out_indices is None: + out_indices = (12, ) if 'small' in arch else (16, ) + for order, index in enumerate(out_indices): + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch]) + 2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch]) + 2}). 
' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.layers = self._make_layer() + self.feat_dim = self.arch_settings[arch][-1][1] + + def _make_layer(self): + layers = [] + layer_setting = self.arch_settings[self.arch] + in_channels = 16 + + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict( + type='HSigmoid', + bias=3, + divisor=6, + min_value=0, + max_value=1))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = 'layer{}'.format(i + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # Build the last layer before pooling + # TODO: No dilation + layer = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = 'layer{}'.format(len(layer_setting) + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + return layers + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV3, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/mobileone.py b/mmpretrain/models/backbones/mobileone.py new file mode 100644 index 0000000..1111441 --- /dev/null +++ b/mmpretrain/models/backbones/mobileone.py @@ -0,0 +1,515 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from official impl https://github.com/apple/ml-mobileone/blob/main/mobileone.py # noqa: E501 +from typing import Optional, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .base_backbone import BaseBackbone + + +class MobileOneBlock(BaseModule): + """MobileOne block for MobileOne backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + kernel_size (int): The kernel size of the convs in the block. 
If the + kernel size is large than 1, there will be a ``branch_scale`` in + the block. + num_convs (int): Number of the convolution branches in the block. + stride (int): Stride of convolution layers. Defaults to 1. + padding (int): Padding of the convolution layers. Defaults to 1. + dilation (int): Dilation of the convolution layers. Defaults to 1. + groups (int): Groups of the convolution layers. Defaults to 1. + se_cfg (None or dict): The configuration of the se module. + Defaults to None. + norm_cfg (dict): Configuration to construct and config norm layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + deploy (bool): Whether the model structure is in the deployment mode. + Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int, + num_convs: int, + stride: int = 1, + padding: int = 1, + dilation: int = 1, + groups: int = 1, + se_cfg: Optional[dict] = None, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = dict(type='BN'), + act_cfg: Optional[dict] = dict(type='ReLU'), + deploy: bool = False, + init_cfg: Optional[dict] = None): + super(MobileOneBlock, self).__init__(init_cfg) + + assert se_cfg is None or isinstance(se_cfg, dict) + if se_cfg is not None: + self.se = SELayer(channels=out_channels, **se_cfg) + else: + self.se = nn.Identity() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_conv_branches = num_convs + self.stride = stride + self.padding = padding + self.se_cfg = se_cfg + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.deploy = deploy + self.groups = groups + self.dilation = dilation + + if deploy: + self.branch_reparam = build_conv_layer( + conv_cfg, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + groups=self.groups, + stride=stride, + padding=padding, + dilation=dilation, + bias=True) + else: + # judge if input shape and output shape are the same. + # If true, add a normalized identity shortcut. 
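+            # Training-time branches built below (all of them are fused into
+            # a single ``branch_reparam`` conv by ``switch_to_deploy``):
+            #   * ``branch_norm``: BN-only identity branch, only when the
+            #     input and output shapes match (same channels, stride 1),
+            #   * ``branch_scale``: 1x1 conv + BN branch, only when
+            #     ``kernel_size`` > 1,
+            #   * ``branch_conv_list``: ``num_convs`` KxK conv + BN branches.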
+ if out_channels == in_channels and stride == 1: + self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1] + else: + self.branch_norm = None + + self.branch_scale = None + if kernel_size > 1: + self.branch_scale = self.create_conv_bn(kernel_size=1) + + self.branch_conv_list = ModuleList() + for _ in range(num_convs): + self.branch_conv_list.append( + self.create_conv_bn( + kernel_size=kernel_size, + padding=padding, + dilation=dilation)) + + self.act = build_activation_layer(act_cfg) + + def create_conv_bn(self, kernel_size, dilation=1, padding=0): + """cearte a (conv + bn) Sequential layer.""" + conv_bn = Sequential() + conv_bn.add_module( + 'conv', + build_conv_layer( + self.conv_cfg, + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + groups=self.groups, + stride=self.stride, + dilation=dilation, + padding=padding, + bias=False)) + conv_bn.add_module( + 'norm', + build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1]) + + return conv_bn + + def forward(self, x): + + def _inner_forward(inputs): + if self.deploy: + return self.branch_reparam(inputs) + + inner_out = 0 + if self.branch_norm is not None: + inner_out = self.branch_norm(inputs) + + if self.branch_scale is not None: + inner_out += self.branch_scale(inputs) + + for branch_conv in self.branch_conv_list: + inner_out += branch_conv(inputs) + + return inner_out + + return self.act(self.se(_inner_forward(x))) + + def switch_to_deploy(self): + """Switch the model structure from training mode to deployment mode.""" + if self.deploy: + return + assert self.norm_cfg['type'] == 'BN', \ + "Switch is not allowed when norm_cfg['type'] != 'BN'." + + reparam_weight, reparam_bias = self.reparameterize() + self.branch_reparam = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + bias=True) + self.branch_reparam.weight.data = reparam_weight + self.branch_reparam.bias.data = reparam_bias + + for param in self.parameters(): + param.detach_() + delattr(self, 'branch_conv_list') + if hasattr(self, 'branch_scale'): + delattr(self, 'branch_scale') + delattr(self, 'branch_norm') + + self.deploy = True + + def reparameterize(self): + """Fuse all the parameters of all branches. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all + branches. the first element is the weights and the second is + the bias. + """ + weight_conv, bias_conv = 0, 0 + for branch_conv in self.branch_conv_list: + weight, bias = self._fuse_conv_bn(branch_conv) + weight_conv += weight + bias_conv += bias + + weight_scale, bias_scale = 0, 0 + if self.branch_scale is not None: + weight_scale, bias_scale = self._fuse_conv_bn(self.branch_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + weight_scale = F.pad(weight_scale, [pad, pad, pad, pad]) + + weight_norm, bias_norm = 0, 0 + if self.branch_norm: + tmp_conv_bn = self._norm_to_conv(self.branch_norm) + weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn) + + return (weight_conv + weight_scale + weight_norm, + bias_conv + bias_scale + bias_norm) + + def _fuse_conv_bn(self, branch): + """Fuse the parameters in a branch with a conv and bn. + + Args: + branch (mmcv.runner.Sequential): A branch with conv and bn. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The parameters obtained after + fusing the parameters of conv and bn in one branch. 
+ The first element is the weight and the second is the bias. + """ + if branch is None: + return 0, 0 + kernel = branch.conv.weight + running_mean = branch.norm.running_mean + running_var = branch.norm.running_var + gamma = branch.norm.weight + beta = branch.norm.bias + eps = branch.norm.eps + + std = (running_var + eps).sqrt() + fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * kernel + fused_bias = beta - running_mean * gamma / std + + return fused_weight, fused_bias + + def _norm_to_conv(self, branch_nrom): + """Convert a norm layer to a conv-bn sequence towards + ``self.kernel_size``. + + Args: + branch (nn.BatchNorm2d): A branch only with bn in the block. + + Returns: + (mmcv.runner.Sequential): a sequential with conv and bn. + """ + input_dim = self.in_channels // self.groups + conv_weight = torch.zeros( + (self.in_channels, input_dim, self.kernel_size, self.kernel_size), + dtype=branch_nrom.weight.dtype) + + for i in range(self.in_channels): + conv_weight[i, i % input_dim, self.kernel_size // 2, + self.kernel_size // 2] = 1 + conv_weight = conv_weight.to(branch_nrom.weight.device) + + tmp_conv = self.create_conv_bn(kernel_size=self.kernel_size) + tmp_conv.conv.weight.data = conv_weight + tmp_conv.norm = branch_nrom + return tmp_conv + + +@MODELS.register_module() +class MobileOne(BaseBackbone): + """MobileOne backbone. + + A PyTorch impl of : `An Improved One millisecond Mobile Backbone + `_ + + Args: + arch (str | dict): MobileOne architecture. If use string, choose + from 's0', 's1', 's2', 's3' and 's4'. If use dict, it should + have below keys: + + - num_blocks (Sequence[int]): Number of blocks in each stage. + - width_factor (Sequence[float]): Width factor in each stage. + - num_conv_branches (Sequence[int]): Number of conv branches + in each stage. + - num_se_blocks (Sequence[int]): Number of SE layers in each + stage, all the SE layers are placed in the subsequent order + in each stage. + + Defaults to 's0'. + in_channels (int): Number of input image channels. Default: 3. + out_indices (Sequence[int] | int): Output from which stages. + Defaults to ``(3, )``. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Defaults to -1. + conv_cfg (dict | None): The config dict for conv layers. + Defaults to None. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + deploy (bool): Whether to switch the model structure to deployment + mode. Defaults to False. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> from mmpretrain.models import MobileOne + >>> import torch + >>> x = torch.rand(1, 3, 224, 224) + >>> model = MobileOne("s0", out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> outputs = model(x) + >>> for out in outputs: + ... 
print(tuple(out.shape)) + (1, 48, 56, 56) + (1, 128, 28, 28) + (1, 256, 14, 14) + (1, 1024, 7, 7) + """ + + arch_zoo = { + 's0': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[0.75, 1.0, 1.0, 2.0], + num_conv_branches=[4, 4, 4, 4], + num_se_blocks=[0, 0, 0, 0]), + 's1': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[1.5, 1.5, 2.0, 2.5], + num_conv_branches=[1, 1, 1, 1], + num_se_blocks=[0, 0, 0, 0]), + 's2': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[1.5, 2.0, 2.5, 4.0], + num_conv_branches=[1, 1, 1, 1], + num_se_blocks=[0, 0, 0, 0]), + 's3': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[2.0, 2.5, 3.0, 4.0], + num_conv_branches=[1, 1, 1, 1], + num_se_blocks=[0, 0, 0, 0]), + 's4': + dict( + num_blocks=[2, 8, 10, 1], + width_factor=[3.0, 3.5, 3.5, 4.0], + num_conv_branches=[1, 1, 1, 1], + num_se_blocks=[0, 0, 5, 1]) + } + + def __init__(self, + arch, + in_channels=3, + out_indices=(3, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + se_cfg=dict(ratio=16), + deploy=False, + norm_eval=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm']) + ]): + super(MobileOne, self).__init__(init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_zoo, f'"arch": "{arch}"' \ + f' is not one of the {list(self.arch_zoo.keys())}' + arch = self.arch_zoo[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + self.arch = arch + for k, value in self.arch.items(): + assert isinstance(value, list) and len(value) == 4, \ + f'the value of {k} in arch must be list with 4 items.' + + self.in_channels = in_channels + self.deploy = deploy + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.se_cfg = se_cfg + self.act_cfg = act_cfg + + base_channels = [64, 128, 256, 512] + channels = min(64, + int(base_channels[0] * self.arch['width_factor'][0])) + self.stage0 = MobileOneBlock( + self.in_channels, + channels, + stride=2, + kernel_size=3, + num_convs=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + deploy=deploy) + + self.in_planes = channels + self.stages = [] + for i, num_blocks in enumerate(self.arch['num_blocks']): + planes = int(base_channels[i] * self.arch['width_factor'][i]) + + stage = self._make_stage(planes, num_blocks, + arch['num_se_blocks'][i], + arch['num_conv_branches'][i]) + + stage_name = f'stage{i + 1}' + self.add_module(stage_name, stage) + self.stages.append(stage_name) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = len(self.stages) + index + assert 0 <= out_indices[i] <= len(self.stages), \ + f'Invalid out_indices {index}.' 
+ self.out_indices = out_indices + + def _make_stage(self, planes, num_blocks, num_se, num_conv_branches): + strides = [2] + [1] * (num_blocks - 1) + if num_se > num_blocks: + raise ValueError('Number of SE blocks cannot ' + 'exceed number of layers.') + blocks = [] + for i in range(num_blocks): + use_se = False + if i >= (num_blocks - num_se): + use_se = True + + blocks.append( + # Depthwise conv + MobileOneBlock( + in_channels=self.in_planes, + out_channels=self.in_planes, + kernel_size=3, + num_convs=num_conv_branches, + stride=strides[i], + padding=1, + groups=self.in_planes, + se_cfg=self.se_cfg if use_se else None, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + deploy=self.deploy)) + + blocks.append( + # Pointwise conv + MobileOneBlock( + in_channels=self.in_planes, + out_channels=planes, + kernel_size=1, + num_convs=num_conv_branches, + stride=1, + padding=0, + se_cfg=self.se_cfg if use_se else None, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + deploy=self.deploy)) + + self.in_planes = planes + + return Sequential(*blocks) + + def forward(self, x): + x = self.stage0(x) + outs = [] + for i, stage_name in enumerate(self.stages): + stage = getattr(self, stage_name) + x = stage(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stage0.eval() + for param in self.stage0.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + stage = getattr(self, f'stage{i+1}') + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + """switch the mobile to train mode or not.""" + super(MobileOne, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def switch_to_deploy(self): + """switch the model to deploy mode, which has smaller amount of + parameters and calculations.""" + for m in self.modules(): + if isinstance(m, MobileOneBlock): + m.switch_to_deploy() + self.deploy = True diff --git a/mmpretrain/models/backbones/mobilevit.py b/mmpretrain/models/backbones/mobilevit.py new file mode 100644 index 0000000..9e4043f --- /dev/null +++ b/mmpretrain/models/backbones/mobilevit.py @@ -0,0 +1,431 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Callable, Optional, Sequence + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_norm_layer +from torch import nn + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone +from .mobilenet_v2 import InvertedResidual +from .vision_transformer import TransformerEncoderLayer + + +class MobileVitBlock(nn.Module): + """MobileViT block. + + According to the paper, the MobileViT block has a local representation. + a transformer-as-convolution layer which consists of a global + representation with unfolding and folding, and a final fusion layer. + + Args: + in_channels (int): Number of input image channels. + transformer_dim (int): Number of transformer channels. + ffn_dim (int): Number of ffn channels in transformer block. + out_channels (int): Number of channels in output. + conv_ksize (int): Conv kernel size in local representation + and fusion. Defaults to 3. + conv_cfg (dict, optional): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict, optional): Config dict for normalization layer. 
+ Defaults to dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer. + Defaults to dict(type='Swish'). + num_transformer_blocks (int): Number of transformer blocks in + a MobileViT block. Defaults to 2. + patch_size (int): Patch size for unfolding and folding. + Defaults to 2. + num_heads (int): Number of heads in global representation. + Defaults to 4. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + no_fusion (bool): Whether to remove the fusion layer. + Defaults to False. + transformer_norm_cfg (dict, optional): Config dict for normalization + layer in transformer. Defaults to dict(type='LN'). + """ + + def __init__( + self, + in_channels: int, + transformer_dim: int, + ffn_dim: int, + out_channels: int, + conv_ksize: int = 3, + conv_cfg: Optional[dict] = None, + norm_cfg: Optional[dict] = dict(type='BN'), + act_cfg: Optional[dict] = dict(type='Swish'), + num_transformer_blocks: int = 2, + patch_size: int = 2, + num_heads: int = 4, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + no_fusion: bool = False, + transformer_norm_cfg: Callable = dict(type='LN'), + ): + super(MobileVitBlock, self).__init__() + + self.local_rep = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=conv_ksize, + padding=int((conv_ksize - 1) / 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=in_channels, + out_channels=transformer_dim, + kernel_size=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=None, + act_cfg=None), + ) + + global_rep = [ + TransformerEncoderLayer( + embed_dims=transformer_dim, + num_heads=num_heads, + feedforward_channels=ffn_dim, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + qkv_bias=True, + act_cfg=dict(type='Swish'), + norm_cfg=transformer_norm_cfg) + for _ in range(num_transformer_blocks) + ] + global_rep.append( + build_norm_layer(transformer_norm_cfg, transformer_dim)[1]) + self.global_rep = nn.Sequential(*global_rep) + + self.conv_proj = ConvModule( + in_channels=transformer_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if no_fusion: + self.conv_fusion = None + else: + self.conv_fusion = ConvModule( + in_channels=in_channels + out_channels, + out_channels=out_channels, + kernel_size=conv_ksize, + padding=int((conv_ksize - 1) / 2), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.patch_size = (patch_size, patch_size) + self.patch_area = self.patch_size[0] * self.patch_size[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + + # Local representation + x = self.local_rep(x) + + # Unfold (feature map -> patches) + patch_h, patch_w = self.patch_size + B, C, H, W = x.shape + new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil( + W / patch_w) * patch_w + num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w # noqa + num_patches = num_patch_h * num_patch_w # N + interpolate = False + if new_h != H or new_w != W: + # Note: Padding can be done, but then it needs to be handled in attention function. 
# noqa + x = F.interpolate( + x, size=(new_h, new_w), mode='bilinear', align_corners=False) + interpolate = True + + # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w] + x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, + patch_w).transpose(1, 2) + # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w # noqa + x = x.reshape(B, C, num_patches, + self.patch_area).transpose(1, 3).reshape( + B * self.patch_area, num_patches, -1) + + # Global representations + x = self.global_rep(x) + + # Fold (patch -> feature map) + # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w] + x = x.contiguous().view(B, self.patch_area, num_patches, -1) + x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, + patch_h, patch_w) + # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] # noqa + x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, + num_patch_w * patch_w) + if interpolate: + x = F.interpolate( + x, size=(H, W), mode='bilinear', align_corners=False) + + x = self.conv_proj(x) + if self.conv_fusion is not None: + x = self.conv_fusion(torch.cat((shortcut, x), dim=1)) + return x + + +@MODELS.register_module() +class MobileViT(BaseBackbone): + """MobileViT backbone. + + A PyTorch implementation of : `MobileViT: Light-weight, General-purpose, + and Mobile-friendly Vision Transformer `_ + + Modified from the `official repo + `_ + and `timm + `_. + + Args: + arch (str | List[list]): Architecture of MobileViT. + + - If a string, choose from "small", "x_small" and "xx_small". + + - If a list, every item should be also a list, and the first item + of the sub-list can be chosen from "moblienetv2" and "mobilevit", + which indicates the type of this layer sequence. If "mobilenetv2", + the other items are the arguments of :attr:`~MobileViT.make_mobilenetv2_layer` + (except ``in_channels``) and if "mobilevit", the other items are + the arguments of :attr:`~MobileViT.make_mobilevit_layer` + (except ``in_channels``). + + Defaults to "small". + in_channels (int): Number of input image channels. Defaults to 3. + stem_channels (int): Channels of stem layer. Defaults to 16. + last_exp_factor (int): Channels expand factor of last layer. + Defaults to 4. + out_indices (Sequence[int]): Output from which stages. + Defaults to (4, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Defaults to None, which means using conv2d. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer. + Defaults to dict(type='Swish'). + init_cfg (dict, optional): Initialization config dict. + """ # noqa + + # Parameters to build layers. The first param is the type of layer. + # For `mobilenetv2` layer, the rest params from left to right are: + # out channels, stride, num of blocks, expand_ratio. + # For `mobilevit` layer, the rest params from left to right are: + # out channels, stride, transformer_channels, ffn channels, + # num of transformer blocks, expand_ratio. 
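+    # For example, the row ['mobilevit', 96, 2, 144, 288, 2, 4] of the
+    # 'small' arch below builds a stage with 96 output channels and stride 2,
+    # whose MobileViT block uses a transformer dim of 144, an FFN dim of 288
+    # and 2 transformer blocks (the leading InvertedResidual uses
+    # expand_ratio 4).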
+ arch_settings = { + 'small': [ + ['mobilenetv2', 32, 1, 1, 4], + ['mobilenetv2', 64, 2, 3, 4], + ['mobilevit', 96, 2, 144, 288, 2, 4], + ['mobilevit', 128, 2, 192, 384, 4, 4], + ['mobilevit', 160, 2, 240, 480, 3, 4], + ], + 'x_small': [ + ['mobilenetv2', 32, 1, 1, 4], + ['mobilenetv2', 48, 2, 3, 4], + ['mobilevit', 64, 2, 96, 192, 2, 4], + ['mobilevit', 80, 2, 120, 240, 4, 4], + ['mobilevit', 96, 2, 144, 288, 3, 4], + ], + 'xx_small': [ + ['mobilenetv2', 16, 1, 1, 2], + ['mobilenetv2', 24, 2, 3, 2], + ['mobilevit', 48, 2, 64, 128, 2, 2], + ['mobilevit', 64, 2, 80, 160, 4, 2], + ['mobilevit', 80, 2, 96, 192, 3, 2], + ] + } + + def __init__(self, + arch='small', + in_channels=3, + stem_channels=16, + last_exp_factor=4, + out_indices=(4, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='Swish'), + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(MobileViT, self).__init__(init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a list.' + arch = self.arch_settings[arch] + + self.arch = arch + self.num_stages = len(arch) + + # check out indices and frozen stages + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + if frozen_stages not in range(-1, self.num_stages): + raise ValueError('frozen_stages must be in range(-1, ' + f'{self.num_stages}). ' + f'But received {frozen_stages}') + self.frozen_stages = frozen_stages + + _make_layer_func = { + 'mobilenetv2': self.make_mobilenetv2_layer, + 'mobilevit': self.make_mobilevit_layer, + } + + self.stem = ConvModule( + in_channels=in_channels, + out_channels=stem_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + in_channels = stem_channels + layers = [] + for i, layer_settings in enumerate(arch): + layer_type, settings = layer_settings[0], layer_settings[1:] + layer, out_channels = _make_layer_func[layer_type](in_channels, + *settings) + layers.append(layer) + in_channels = out_channels + self.layers = nn.Sequential(*layers) + + self.conv_1x1_exp = ConvModule( + in_channels=in_channels, + out_channels=last_exp_factor * in_channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + @staticmethod + def make_mobilevit_layer(in_channels, + out_channels, + stride, + transformer_dim, + ffn_dim, + num_transformer_blocks, + expand_ratio=4): + """Build mobilevit layer, which consists of one InvertedResidual and + one MobileVitBlock. + + Args: + in_channels (int): The input channels. + out_channels (int): The output channels. + stride (int): The stride of the first 3x3 convolution in the + ``InvertedResidual`` layers. + transformer_dim (int): The channels of the transformer layers. + ffn_dim (int): The mid-channels of the feedforward network in + transformer layers. + num_transformer_blocks (int): The number of transformer blocks. + expand_ratio (int): adjusts number of channels of the hidden layer + in ``InvertedResidual`` by this amount. Defaults to 4. 
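+
+        Returns:
+            Tuple[nn.Sequential, int]: The built layer sequence and its
+                number of output channels.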
+ """ + layer = [] + layer.append( + InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + act_cfg=dict(type='Swish'), + )) + layer.append( + MobileVitBlock( + in_channels=out_channels, + transformer_dim=transformer_dim, + ffn_dim=ffn_dim, + out_channels=out_channels, + num_transformer_blocks=num_transformer_blocks, + )) + return nn.Sequential(*layer), out_channels + + @staticmethod + def make_mobilenetv2_layer(in_channels, + out_channels, + stride, + num_blocks, + expand_ratio=4): + """Build mobilenetv2 layer, which consists of several InvertedResidual + layers. + + Args: + in_channels (int): The input channels. + out_channels (int): The output channels. + stride (int): The stride of the first 3x3 convolution in the + ``InvertedResidual`` layers. + num_blocks (int): The number of ``InvertedResidual`` blocks. + expand_ratio (int): adjusts number of channels of the hidden layer + in ``InvertedResidual`` by this amount. Defaults to 4. + """ + layer = [] + for i in range(num_blocks): + stride = stride if i == 0 else 1 + + layer.append( + InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + act_cfg=dict(type='Swish'), + )) + in_channels = out_channels + return nn.Sequential(*layer), out_channels + + def _freeze_stages(self): + for i in range(0, self.frozen_stages): + layer = self.layers[i] + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileViT, self).train(mode) + self._freeze_stages() + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + x = self.conv_1x1_exp(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/mvit.py b/mmpretrain/models/backbones/mvit.py new file mode 100644 index 0000000..68aee97 --- /dev/null +++ b/mmpretrain/models/backbones/mvit.py @@ -0,0 +1,700 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import to_2tuple + +from ..builder import BACKBONES +from ..utils import resize_pos_embed +from .base_backbone import BaseBackbone + + +def resize_decomposed_rel_pos(rel_pos, q_size, k_size): + """Get relative positional embeddings according to the relative positions + of query and key sizes. + + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + resized = F.interpolate( + # (L, C) -> (1, C, L) + rel_pos.transpose(0, 1).unsqueeze(0), + size=max_rel_dist, + mode='linear', + ) + # (1, C, L) -> (L, C) + resized = resized.squeeze(0).transpose(0, 1) + else: + resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. 
+ q_h_ratio = max(k_size / q_size, 1.0) + k_h_ratio = max(q_size / k_size, 1.0) + q_coords = torch.arange(q_size)[:, None] * q_h_ratio + k_coords = torch.arange(k_size)[None, :] * k_h_ratio + relative_coords = (q_coords - k_coords) + (k_size - 1) * k_h_ratio + + return resized[relative_coords.long()] + + +def add_decomposed_rel_pos(attn, + q, + q_shape, + k_shape, + rel_pos_h, + rel_pos_w, + has_cls_token=False): + """Spatial Relative Positional Embeddings.""" + sp_idx = 1 if has_cls_token else 0 + B, num_heads, _, C = q.shape + q_h, q_w = q_shape + k_h, k_w = k_shape + + Rh = resize_decomposed_rel_pos(rel_pos_h, q_h, k_h) + Rw = resize_decomposed_rel_pos(rel_pos_w, q_w, k_w) + + r_q = q[:, :, sp_idx:].reshape(B, num_heads, q_h, q_w, C) + rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh) + rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw) + rel_pos_embed = rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :] + + attn_map = attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) + attn_map += rel_pos_embed + attn[:, :, sp_idx:, sp_idx:] = attn_map.view(B, -1, q_h * q_w, k_h * k_w) + + return attn + + +class MLP(BaseModule): + """Two-layer multilayer perceptron. + + Comparing with :class:`mmcv.cnn.bricks.transformer.FFN`, this class allows + different input and output channel numbers. + + Args: + in_channels (int): The number of input channels. + hidden_channels (int, optional): The number of hidden layer channels. + If None, same as the ``in_channels``. Defaults to None. + out_channels (int, optional): The number of output channels. If None, + same as the ``in_channels``. Defaults to None. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. + """ + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + act_cfg=dict(type='GELU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Linear(hidden_channels, out_channels) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +def attention_pool(x: torch.Tensor, + pool: nn.Module, + in_size: tuple, + norm: Optional[nn.Module] = None): + """Pooling the feature tokens. + + Args: + x (torch.Tensor): The input tensor, should be with shape + ``(B, num_heads, L, C)`` or ``(B, L, C)``. + pool (nn.Module): The pooling module. + in_size (Tuple[int]): The shape of the input feature map. + norm (nn.Module, optional): The normalization module. + Defaults to None. + """ + ndim = x.ndim + if ndim == 4: + B, num_heads, L, C = x.shape + elif ndim == 3: + num_heads = 1 + B, L, C = x.shape + else: + raise RuntimeError(f'Unsupported input dimension {x.shape}') + + H, W = in_size + assert L == H * W + + # (B, num_heads, H*W, C) -> (B*num_heads, C, H, W) + x = x.reshape(B * num_heads, H, W, C).permute(0, 3, 1, 2).contiguous() + x = pool(x) + out_size = x.shape[-2:] + + # (B*num_heads, C, H', W') -> (B, num_heads, H'*W', C) + x = x.reshape(B, num_heads, C, -1).transpose(2, 3) + + if norm is not None: + x = norm(x) + + if ndim == 3: + x = x.squeeze(1) + + return x, out_size + + +class MultiScaleAttention(BaseModule): + """Multiscale Multi-head Attention block. + + Args: + in_dims (int): Number of input channels. + out_dims (int): Number of output channels. 
+ num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key and + value. Defaults to True. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='LN')``. + pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + stride_q (int): stride size for q pooling layer. Defaults to 1. + stride_kv (int): stride size for kv pooling layer. Defaults to 1. + rel_pos_spatial (bool): Whether to enable the spatial relative + position embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + input_size (Tuple[int], optional): The input resolution, necessary + if enable the ``rel_pos_spatial``. Defaults to None. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. + """ + + def __init__(self, + in_dims, + out_dims, + num_heads, + qkv_bias=True, + norm_cfg=dict(type='LN'), + pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + rel_pos_spatial=False, + residual_pooling=True, + input_size=None, + rel_pos_zero_init=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.in_dims = in_dims + self.out_dims = out_dims + + head_dim = out_dims // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(in_dims, out_dims * 3, bias=qkv_bias) + self.proj = nn.Linear(out_dims, out_dims) + + # qkv pooling + pool_padding = [k // 2 for k in pool_kernel] + pool_dims = out_dims // num_heads + + def build_pooling(stride): + pool = nn.Conv2d( + pool_dims, + pool_dims, + pool_kernel, + stride=stride, + padding=pool_padding, + groups=pool_dims, + bias=False, + ) + norm = build_norm_layer(norm_cfg, pool_dims)[1] + return pool, norm + + self.pool_q, self.norm_q = build_pooling(stride_q) + self.pool_k, self.norm_k = build_pooling(stride_kv) + self.pool_v, self.norm_v = build_pooling(stride_kv) + + self.residual_pooling = residual_pooling + + self.rel_pos_spatial = rel_pos_spatial + self.rel_pos_zero_init = rel_pos_zero_init + if self.rel_pos_spatial: + # initialize relative positional embeddings + assert input_size[0] == input_size[1] + + size = input_size[0] + rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim)) + + def init_weights(self): + """Weight initialization.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress rel_pos_zero_init if use pretrained model. 
+ return + + if not self.rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x, in_size): + """Forward the MultiScaleAttention.""" + B, N, _ = x.shape # (B, H*W, C) + + # qkv: (B, H*W, 3, num_heads, C) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1) + # q, k, v: (B, num_heads, H*W, C) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) + + q, q_shape = attention_pool(q, self.pool_q, in_size, norm=self.norm_q) + k, k_shape = attention_pool(k, self.pool_k, in_size, norm=self.norm_k) + v, v_shape = attention_pool(v, self.pool_v, in_size, norm=self.norm_v) + + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_spatial: + attn = add_decomposed_rel_pos(attn, q, q_shape, k_shape, + self.rel_pos_h, self.rel_pos_w) + + attn = attn.softmax(dim=-1) + x = attn @ v + + if self.residual_pooling: + x = x + q + + # (B, num_heads, H'*W', C'//num_heads) -> (B, H'*W', C') + x = x.transpose(1, 2).reshape(B, -1, self.out_dims) + x = self.proj(x) + + return x, q_shape + + +class MultiScaleBlock(BaseModule): + """Multiscale Transformer blocks. + + Args: + in_dims (int): Number of input channels. + out_dims (int): Number of output channels. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of hidden dimensions in MLP layers. + Defaults to 4.0. + qkv_bias (bool): If True, add a learnable bias to query, key and + value. Defaults to True. + drop_path (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='LN')``. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. + qkv_pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + stride_q (int): stride size for q pooling layer. Defaults to 1. + stride_kv (int): stride size for kv pooling layer. Defaults to 1. + rel_pos_spatial (bool): Whether to enable the spatial relative + position embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in + attention layers. If False, multiply it in MLP layers. + Defaults to True. + input_size (Tuple[int], optional): The input resolution, necessary + if enable the ``rel_pos_spatial``. Defaults to None. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + init_cfg (dict, optional): The config of weight initialization. + Defaults to None. 
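+
+    Note:
+        When ``stride_q > 1``, the shortcut of the attention sub-block is
+        downsampled with an extra max-pooling layer (``pool_skip``) so that
+        it matches the spatial size of the pooled attention output.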
+ """ + + def __init__( + self, + in_dims, + out_dims, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + qkv_pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + rel_pos_spatial=True, + residual_pooling=True, + dim_mul_in_attention=True, + input_size=None, + rel_pos_zero_init=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.in_dims = in_dims + self.out_dims = out_dims + self.norm1 = build_norm_layer(norm_cfg, in_dims)[1] + self.dim_mul_in_attention = dim_mul_in_attention + + attn_dims = out_dims if dim_mul_in_attention else in_dims + self.attn = MultiScaleAttention( + in_dims, + attn_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + pool_kernel=qkv_pool_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + rel_pos_spatial=rel_pos_spatial, + residual_pooling=residual_pooling, + input_size=input_size, + rel_pos_zero_init=rel_pos_zero_init) + self.drop_path = DropPath( + drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, attn_dims)[1] + + self.mlp = MLP( + in_channels=attn_dims, + hidden_channels=int(attn_dims * mlp_ratio), + out_channels=out_dims, + act_cfg=act_cfg) + + if in_dims != out_dims: + self.proj = nn.Linear(in_dims, out_dims) + else: + self.proj = None + + if stride_q > 1: + kernel_skip = stride_q + 1 + padding_skip = int(kernel_skip // 2) + self.pool_skip = nn.MaxPool2d( + kernel_skip, stride_q, padding_skip, ceil_mode=False) + + if input_size is not None: + input_size = to_2tuple(input_size) + out_size = [size // stride_q for size in input_size] + self.init_out_size = out_size + else: + self.init_out_size = None + else: + self.pool_skip = None + self.init_out_size = input_size + + def forward(self, x, in_size): + x_norm = self.norm1(x) + x_attn, out_size = self.attn(x_norm, in_size) + + if self.dim_mul_in_attention and self.proj is not None: + skip = self.proj(x_norm) + else: + skip = x + + if self.pool_skip is not None: + skip, _ = attention_pool(skip, self.pool_skip, in_size) + + x = skip + self.drop_path(x_attn) + x_norm = self.norm2(x) + x_mlp = self.mlp(x_norm) + + if not self.dim_mul_in_attention and self.proj is not None: + skip = self.proj(x_norm) + else: + skip = x + + x = skip + self.drop_path(x_mlp) + + return x, out_size + + +@BACKBONES.register_module() +class MViT(BaseBackbone): + """Multi-scale ViT v2. + + A PyTorch implement of : `MViTv2: Improved Multiscale Vision Transformers + for Classification and Detection `_ + + Inspiration from `the official implementation + `_ and `the detectron2 + implementation `_ + + Args: + arch (str | dict): MViT architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of layers. + - **num_heads** (int): The number of heads in attention + modules of the initial layer. + - **downscale_indices** (List[int]): The layer indices to downscale + the feature map. + + Defaults to 'base'. + img_size (int): The expected input image shape. Defaults to 224. + in_channels (int): The num of input channels. Defaults to 3. + out_scales (int | Sequence[int]): The output scale indices. + They should not exceed the length of ``downscale_indices``. + Defaults to -1, which means the last scale. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. 
+ use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embedding vector resize. Defaults to "bicubic". + pool_kernel (tuple): kernel size for qkv pooling layers. + Defaults to (3, 3). + dim_mul (int): The magnification for ``embed_dims`` in the downscale + layers. Defaults to 2. + head_mul (int): The magnification for ``num_heads`` in the downscale + layers. Defaults to 2. + adaptive_kv_stride (int): The stride size for kv pooling in the initial + layer. Defaults to 4. + rel_pos_spatial (bool): Whether to enable the spatial relative position + embedding. Defaults to True. + residual_pooling (bool): Whether to enable the residual connection + after attention pooling. Defaults to True. + dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in + attention layers. If False, multiply it in MLP layers. + Defaults to True. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. Defaults to False. + mlp_ratio (float): Ratio of hidden dimensions in MLP layers. + Defaults to 4.0. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN', eps=1e-6)``. + patch_cfg (dict): Config dict for the patch embedding layer. + Defaults to ``dict(kernel_size=7, stride=4, padding=3)``. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> import torch + >>> from mmpretrain.models import build_backbone + >>> + >>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3]) + >>> model = build_backbone(cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outputs = model(inputs) + >>> for i, output in enumerate(outputs): + >>> print(f'scale{i}: {output.shape}') + scale0: torch.Size([1, 96, 56, 56]) + scale1: torch.Size([1, 192, 28, 28]) + scale2: torch.Size([1, 384, 14, 14]) + scale3: torch.Size([1, 768, 7, 7]) + """ + arch_zoo = { + 'tiny': { + 'embed_dims': 96, + 'num_layers': 10, + 'num_heads': 1, + 'downscale_indices': [1, 3, 8] + }, + 'small': { + 'embed_dims': 96, + 'num_layers': 16, + 'num_heads': 1, + 'downscale_indices': [1, 3, 14] + }, + 'base': { + 'embed_dims': 96, + 'num_layers': 24, + 'num_heads': 1, + 'downscale_indices': [2, 5, 21] + }, + 'large': { + 'embed_dims': 144, + 'num_layers': 48, + 'num_heads': 2, + 'downscale_indices': [2, 8, 44] + }, + } + num_extra_tokens = 0 + + def __init__(self, + arch='base', + img_size=224, + in_channels=3, + out_scales=-1, + drop_path_rate=0., + use_abs_pos_embed=False, + interpolate_mode='bicubic', + pool_kernel=(3, 3), + dim_mul=2, + head_mul=2, + adaptive_kv_stride=4, + rel_pos_spatial=True, + residual_pooling=True, + dim_mul_in_attention=True, + rel_pos_zero_init=False, + mlp_ratio=4., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + patch_cfg=dict(kernel_size=7, stride=4, padding=3), + init_cfg=None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'downscale_indices' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = 
self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.num_heads = self.arch_settings['num_heads'] + self.downscale_indices = self.arch_settings['downscale_indices'] + self.num_scales = len(self.downscale_indices) + 1 + self.stage_indices = { + index - 1: i + for i, index in enumerate(self.downscale_indices) + } + self.stage_indices[self.num_layers - 1] = self.num_scales - 1 + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + + if isinstance(out_scales, int): + out_scales = [out_scales] + assert isinstance(out_scales, Sequence), \ + f'"out_scales" must by a sequence or int, ' \ + f'get {type(out_scales)} instead.' + for i, index in enumerate(out_scales): + if index < 0: + out_scales[i] = self.num_scales + index + assert 0 <= out_scales[i] <= self.num_scales, \ + f'Invalid out_scales {index}' + self.out_scales = sorted(list(out_scales)) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + # Set absolute position embedding + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.blocks = ModuleList() + out_dims_list = [self.embed_dims] + num_heads = self.num_heads + stride_kv = adaptive_kv_stride + input_size = self.patch_resolution + for i in range(self.num_layers): + if i in self.downscale_indices: + num_heads *= head_mul + stride_q = 2 + stride_kv = max(stride_kv // 2, 1) + else: + stride_q = 1 + + # Set output embed_dims + if dim_mul_in_attention and i in self.downscale_indices: + # multiply embed_dims in downscale layers. + out_dims = out_dims_list[-1] * dim_mul + elif not dim_mul_in_attention and i + 1 in self.downscale_indices: + # multiply embed_dims before downscale layers. + out_dims = out_dims_list[-1] * dim_mul + else: + out_dims = out_dims_list[-1] + + attention_block = MultiScaleBlock( + in_dims=out_dims_list[-1], + out_dims=out_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_cfg=norm_cfg, + qkv_pool_kernel=pool_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + rel_pos_spatial=rel_pos_spatial, + residual_pooling=residual_pooling, + dim_mul_in_attention=dim_mul_in_attention, + input_size=input_size, + rel_pos_zero_init=rel_pos_zero_init) + self.blocks.append(attention_block) + + input_size = attention_block.init_out_size + out_dims_list.append(out_dims) + + if i in self.stage_indices: + stage_index = self.stage_indices[i] + if stage_index in self.out_scales: + norm_layer = build_norm_layer(norm_cfg, out_dims)[1] + self.add_module(f'norm{stage_index}', norm_layer) + + def init_weights(self): + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. 
+ return + + if self.use_abs_pos_embed: + trunc_normal_(self.pos_embed, std=0.02) + + def forward(self, x): + """Forward the MViT.""" + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + + outs = [] + for i, block in enumerate(self.blocks): + x, patch_resolution = block(x, patch_resolution) + + if i in self.stage_indices: + stage_index = self.stage_indices[i] + if stage_index in self.out_scales: + B, _, C = x.shape + x = getattr(self, f'norm{stage_index}')(x) + out = x.transpose(1, 2).reshape(B, C, *patch_resolution) + outs.append(out.contiguous()) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/poolformer.py b/mmpretrain/models/backbones/poolformer.py new file mode 100644 index 0000000..e2ad670 --- /dev/null +++ b/mmpretrain/models/backbones/poolformer.py @@ -0,0 +1,416 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class PatchEmbed(nn.Module): + """Patch Embedding module implemented by a layer of convolution. + + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H/stride, W/stride] + Args: + patch_size (int): Patch size of the patch embedding. Defaults to 16. + stride (int): Stride of the patch embedding. Defaults to 16. + padding (int): Padding of the patch embedding. Defaults to 0. + in_chans (int): Input channels. Defaults to 3. + embed_dim (int): Output dimension of the patch embedding. + Defaults to 768. + norm_layer (module): Normalization module. Defaults to None (not use). + """ + + def __init__(self, + patch_size=16, + stride=16, + padding=0, + in_chans=3, + embed_dim=768, + norm_layer=None): + super().__init__() + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + + +class Pooling(nn.Module): + """Pooling module. + + Args: + pool_size (int): Pooling size. Defaults to 3. + """ + + def __init__(self, pool_size=3): + super().__init__() + self.pool = nn.AvgPool2d( + pool_size, + stride=1, + padding=pool_size // 2, + count_include_pad=False) + + def forward(self, x): + return self.pool(x) - x + + +class Mlp(nn.Module): + """Mlp implemented by with 1*1 convolutions. + + Input: Tensor with shape [B, C, H, W]. + Output: Tensor with shape [B, C, H, W]. + Args: + in_features (int): Dimension of input features. + hidden_features (int): Dimension of hidden features. + out_features (int): Dimension of output features. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0.0. 
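+
+    Example (illustrative; the spatial size is preserved because both
+    fully-connected layers are 1x1 convolutions):
+        >>> import torch
+        >>> mlp = Mlp(in_features=64, hidden_features=256)
+        >>> mlp(torch.rand(1, 64, 14, 14)).shape
+        torch.Size([1, 64, 14, 14])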
+ """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class PoolFormerBlock(BaseModule): + """PoolFormer Block. + + Args: + dim (int): Embedding dim. + pool_size (int): Pooling size. Defaults to 3. + mlp_ratio (float): Mlp expansion ratio. Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='GN', num_groups=1)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-5. + """ + + def __init__(self, + dim, + pool_size=3, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = Pooling(pool_size=pool_size) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + # The following two techniques are useful to train deep PoolFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + return x + + +def basic_blocks(dim, + index, + layers, + pool_size=3, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop_rate=.0, + drop_path_rate=0., + layer_scale_init_value=1e-5): + """ + generate PoolFormer blocks for a stage + return: PoolFormer blocks + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / ( + sum(layers) - 1) + blocks.append( + PoolFormerBlock( + dim, + pool_size=pool_size, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + layer_scale_init_value=layer_scale_init_value, + )) + blocks = nn.Sequential(*blocks) + + return blocks + + +@MODELS.register_module() +class PoolFormer(BaseBackbone): + """PoolFormer. + + A PyTorch implementation of PoolFormer introduced by: + `MetaFormer is Actually What You Need for Vision `_ + + Modified from the `official repo + `. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``PoolFormer.arch_settings``. And if dict, it + should include the following two keys: + + - layers (list[int]): Number of blocks at each stage. 
+ - embed_dims (list[int]): The number of channels at each stage. + - mlp_ratios (list[int]): Expansion ratio of MLPs. + - layer_scale_init_value (float): Init value for Layer Scale. + + Defaults to 'S12'. + + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + in_patch_size (int): The patch size of input image patch embedding. + Defaults to 7. + in_stride (int): The stride of input image patch embedding. + Defaults to 4. + in_pad (int): The padding of input image patch embedding. + Defaults to 2. + down_patch_size (int): The patch size of downsampling patch embedding. + Defaults to 3. + down_stride (int): The stride of downsampling patch embedding. + Defaults to 2. + down_pad (int): The padding of downsampling patch embedding. + Defaults to 1. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which network position. + Index 0-6 respectively corresponds to + [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4] + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): Initialization config dict + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims, --mlp_ratios: + # embedding dims and mlp ratios for the four stages + # --downsamples: flags to apply downsampling or not in four blocks + arch_settings = { + 's12': { + 'layers': [2, 2, 6, 2], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's24': { + 'layers': [4, 4, 12, 4], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm48': { + 'layers': [8, 8, 24, 8], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + } + + def __init__(self, + arch='s12', + pool_size=3, + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + in_patch_size=7, + in_stride=4, + in_pad=2, + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_rate=0., + drop_path_rate=0., + out_indices=-1, + frozen_stages=0, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'layers' in arch and 'embed_dims' in arch, \ + f'The arch dict must have "layers" and "embed_dims", ' \ + f'but got {list(arch.keys())}.' 
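+        # Only 'layers' and 'embed_dims' are required in a custom arch dict;
+        # 'mlp_ratios' and 'layer_scale_init_value' fall back to defaults below.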
+ + layers = arch['layers'] + embed_dims = arch['embed_dims'] + mlp_ratios = arch['mlp_ratios'] \ + if 'mlp_ratios' in arch else [4, 4, 4, 4] + layer_scale_init_value = arch['layer_scale_init_value'] \ + if 'layer_scale_init_value' in arch else 1e-5 + + self.patch_embed = PatchEmbed( + patch_size=in_patch_size, + stride=in_stride, + padding=in_pad, + in_chans=3, + embed_dim=embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(layers)): + stage = basic_blocks( + embed_dims[i], + i, + layers, + pool_size=pool_size, + mlp_ratio=mlp_ratios[i], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value) + network.append(stage) + if i >= len(layers) - 1: + break + if embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + padding=down_pad, + in_chans=embed_dims[i], + embed_dim=embed_dims[i + 1])) + + self.network = nn.ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 7 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + if self.out_indices: + for i_layer in self.out_indices: + layer = build_norm_layer(norm_cfg, + embed_dims[(i_layer + 1) // 2])[1] + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + + def forward_embeddings(self, x): + x = self.patch_embed(x) + return x + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(x) + outs.append(x_out) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + # Include both block and downsample layer. + module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(PoolFormer, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/regnet.py b/mmpretrain/models/backbones/regnet.py new file mode 100644 index 0000000..85dbdef --- /dev/null +++ b/mmpretrain/models/backbones/regnet.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import ResNet +from .resnext import Bottleneck + + +@MODELS.register_module() +class RegNet(ResNet): + """RegNet backbone. + + More details can be found in `paper `_ . + + Args: + arch (dict): The parameter of RegNets. 
+ - w0 (int): initial width + - wa (float): slope of width + - wm (float): quantization parameter to quantize the width + - depth (int): depth of the backbone + - group_w (int): width of group + - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck. + strides (Sequence[int]): Strides of the first block of each stage. + base_channels (int): Base channels after stem layer. + in_channels (int): Number of input image channels. Default: 3. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: "pytorch". + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Default: -1. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import RegNet + >>> import torch + >>> self = RegNet( + arch=dict( + w0=88, + wa=26.31, + wm=2.25, + group_w=48, + depth=25, + bot_mul=1.0)) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 96, 8, 8) + (1, 192, 4, 4) + (1, 432, 2, 2) + (1, 1008, 1, 1) + """ + arch_settings = { + 'regnetx_400mf': + dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0), + 'regnetx_800mf': + dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0), + 'regnetx_1.6gf': + dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0), + 'regnetx_3.2gf': + dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0), + 'regnetx_4.0gf': + dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0), + 'regnetx_6.4gf': + dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0), + 'regnetx_8.0gf': + dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0), + 'regnetx_12gf': + dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0), + } + + def __init__(self, + arch, + in_channels=3, + stem_channels=32, + base_channels=32, + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + + # Generate RegNet parameters first + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the' \ + ' arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + widths, num_stages = self.generate_regnet( + arch['w0'], + arch['wa'], + arch['wm'], + arch['depth'], + ) + # Convert to per stage format + stage_widths, stage_blocks = self.get_stages_from_blocks(widths) + # 
Generate group widths and bot muls + group_widths = [arch['group_w'] for _ in range(num_stages)] + self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)] + # Adjust the compatibility of stage_widths and group_widths + stage_widths, group_widths = self.adjust_width_group( + stage_widths, self.bottleneck_ratio, group_widths) + + # Group params by stage + self.stage_widths = stage_widths + self.group_widths = group_widths + self.depth = sum(stage_blocks) + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + if self.deep_stem: + raise NotImplementedError( + 'deep_stem has not been implemented for RegNet') + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.stage_blocks = stage_blocks[:num_stages] + + self._make_stem_layer(in_channels, stem_channels) + + _in_channels = stem_channels + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = self.strides[i] + dilation = self.dilations[i] + group_width = self.group_widths[i] + width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i])) + stage_groups = width // group_width + + res_layer = self.make_res_layer( + block=Bottleneck, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=self.stage_widths[i], + expansion=1, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + base_channels=self.stage_widths[i], + groups=stage_groups, + width_per_group=group_width) + _in_channels = self.stage_widths[i] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = stage_widths[-1] + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + def generate_regnet(self, + initial_width, + width_slope, + width_parameter, + depth, + divisor=8): + """Generates per block width from RegNet parameters. + + Args: + initial_width ([int]): Initial width of the backbone + width_slope ([float]): Slope of the quantized linear function + width_parameter ([int]): Parameter used to quantize the width. + depth ([int]): Depth of the backbone. + divisor (int): The divisor of channels. Defaults to 8. + + Returns: + tuple: tuple containing: + - list: Widths of each stage. + - int: The number of stages. 
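+
+        Example:
+            >>> # Illustrative sketch using the regnetx_400mf parameters.
+            >>> widths, num_stages = self.generate_regnet(
+            ...     initial_width=24, width_slope=24.48,
+            ...     width_parameter=2.54, depth=22)
+            >>> len(widths), num_stages
+            (22, 4)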
+ """ + assert width_slope >= 0 + assert initial_width > 0 + assert width_parameter > 1 + assert initial_width % divisor == 0 + widths_cont = np.arange(depth) * width_slope + initial_width + ks = np.round( + np.log(widths_cont / initial_width) / np.log(width_parameter)) + widths = initial_width * np.power(width_parameter, ks) + widths = np.round(np.divide(widths, divisor)) * divisor + num_stages = len(np.unique(widths)) + widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist() + return widths, num_stages + + @staticmethod + def quantize_float(number, divisor): + """Converts a float to closest non-zero int divisible by divior. + + Args: + number (int): Original number to be quantized. + divisor (int): Divisor used to quantize the number. + + Returns: + int: quantized number that is divisible by devisor. + """ + return int(round(number / divisor) * divisor) + + def adjust_width_group(self, widths, bottleneck_ratio, groups): + """Adjusts the compatibility of widths and groups. + + Args: + widths (list[int]): Width of each stage. + bottleneck_ratio (float): Bottleneck ratio. + groups (int): number of groups in each stage + + Returns: + tuple(list): The adjusted widths and groups of each stage. + """ + bottleneck_width = [ + int(w * b) for w, b in zip(widths, bottleneck_ratio) + ] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)] + bottleneck_width = [ + self.quantize_float(w_bot, g) + for w_bot, g in zip(bottleneck_width, groups) + ] + widths = [ + int(w_bot / b) + for w_bot, b in zip(bottleneck_width, bottleneck_ratio) + ] + return widths, groups + + def get_stages_from_blocks(self, widths): + """Gets widths/stage_blocks of network at each stage. + + Args: + widths (list[int]): Width in each stage. + + Returns: + tuple(list): width and depth of each stage + """ + width_diff = [ + width != width_prev + for width, width_prev in zip(widths + [0], [0] + widths) + ] + stage_widths = [ + width for width, diff in zip(widths, width_diff[:-1]) if diff + ] + stage_blocks = np.diff([ + depth for depth, diff in zip(range(len(width_diff)), width_diff) + if diff + ]).tolist() + return stage_widths, stage_blocks + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/replknet.py b/mmpretrain/models/backbones/replknet.py new file mode 100644 index 0000000..4dce415 --- /dev/null +++ b/mmpretrain/models/backbones/replknet.py @@ -0,0 +1,668 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +def conv_bn(in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1, + norm_cfg=dict(type='BN')): + """Construct a sequential conv and bn. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the convolution. + stride (int): stride of the convolution. + padding (int): stride of the convolution. + groups (int): groups of the convolution. 
+ dilation (int): dilation of the convolution. Default to 1. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + + Returns: + nn.Sequential(): A conv layer and a batch norm layer. + """ + if padding is None: + padding = kernel_size // 2 + result = nn.Sequential() + result.add_module( + 'conv', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False)) + result.add_module('bn', build_norm_layer(norm_cfg, out_channels)[1]) + return result + + +def conv_bn_relu(in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1): + """Construct a sequential conv, bn and relu. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the convolution. + stride (int): stride of the convolution. + padding (int): stride of the convolution. + groups (int): groups of the convolution. + dilation (int): dilation of the convolution. Default to 1. + + Returns: + nn.Sequential(): A conv layer, batch norm layer and a relu function. + """ + + if padding is None: + padding = kernel_size // 2 + result = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + dilation=dilation) + result.add_module('nonlinear', nn.ReLU()) + return result + + +def fuse_bn(conv, bn): + """Fuse the parameters in a branch with a conv and bn. + + Args: + conv (nn.Conv2d): The convolution module to fuse. + bn (nn.BatchNorm2d): The batch normalization to fuse. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The parameters obtained after + fusing the parameters of conv and bn in one branch. + The first element is the weight and the second is the bias. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class ReparamLargeKernelConv(BaseModule): + """Super large kernel implemented by with large convolutions. + + Input: Tensor with shape [B, C, H, W]. + Output: Tensor with shape [B, C, H, W]. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimension of output features. + kernel_size (int): kernel_size of the large convolution. + stride (int): stride of the large convolution. + groups (int): groups of the large convolution. + small_kernel (int): kernel_size of the small convolution. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups, + small_kernel, + small_kernel_merged=False, + init_cfg=None): + super(ReparamLargeKernelConv, self).__init__(init_cfg) + self.kernel_size = kernel_size + self.small_kernel = small_kernel + self.small_kernel_merged = small_kernel_merged + # We assume the conv does not change the feature map size, + # so padding = k//2. + # Otherwise, you may configure padding as you wish, + # and change the padding of small_conv accordingly. 
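+        # During training a parallel small-kernel conv-bn branch regularizes the
+        # large kernel; merge_kernel() later fuses both into one lkb_reparam conv.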
+ padding = kernel_size // 2 + if small_kernel_merged: + self.lkb_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups, + bias=True) + else: + self.lkb_origin = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups) + if small_kernel is not None: + assert small_kernel <= kernel_size + self.small_conv = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=small_kernel, + stride=stride, + padding=small_kernel // 2, + groups=groups, + dilation=1) + + def forward(self, inputs): + if hasattr(self, 'lkb_reparam'): + out = self.lkb_reparam(inputs) + else: + out = self.lkb_origin(inputs) + if hasattr(self, 'small_conv'): + out += self.small_conv(inputs) + return out + + def get_equivalent_kernel_bias(self): + eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + if hasattr(self, 'small_conv'): + small_k, small_b = fuse_bn(self.small_conv.conv, + self.small_conv.bn) + eq_b += small_b + # add to the central part + eq_k += nn.functional.pad( + small_k, [(self.kernel_size - self.small_kernel) // 2] * 4) + return eq_k, eq_b + + def merge_kernel(self): + """Switch the model structure from training mode to deployment mode.""" + if self.small_kernel_merged: + return + eq_k, eq_b = self.get_equivalent_kernel_bias() + self.lkb_reparam = nn.Conv2d( + in_channels=self.lkb_origin.conv.in_channels, + out_channels=self.lkb_origin.conv.out_channels, + kernel_size=self.lkb_origin.conv.kernel_size, + stride=self.lkb_origin.conv.stride, + padding=self.lkb_origin.conv.padding, + dilation=self.lkb_origin.conv.dilation, + groups=self.lkb_origin.conv.groups, + bias=True) + + self.lkb_reparam.weight.data = eq_k + self.lkb_reparam.bias.data = eq_b + self.__delattr__('lkb_origin') + if hasattr(self, 'small_conv'): + self.__delattr__('small_conv') + + self.small_kernel_merged = True + + +class ConvFFN(BaseModule): + """Mlp implemented by with 1*1 convolutions. + + Input: Tensor with shape [B, C, H, W]. + Output: Tensor with shape [B, C, H, W]. + + Args: + in_channels (int): Dimension of input features. + internal_channels (int): Dimension of hidden features. + out_channels (int): Dimension of output features. + drop_path (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels, + internal_channels, + out_channels, + drop_path, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(ConvFFN, self).__init__(init_cfg) + self.drop_path = DropPath( + drop_prob=drop_path) if drop_path > 0. 
else nn.Identity() + self.preffn_bn = build_norm_layer(norm_cfg, in_channels)[1] + self.pw1 = conv_bn( + in_channels=in_channels, + out_channels=internal_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.pw2 = conv_bn( + in_channels=internal_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.nonlinear = build_activation_layer(act_cfg) + + def forward(self, x): + out = self.preffn_bn(x) + out = self.pw1(out) + out = self.nonlinear(out) + out = self.pw2(out) + return x + self.drop_path(out) + + +class RepLKBlock(BaseModule): + """RepLKBlock for RepLKNet backbone. + + Args: + in_channels (int): The input channels of the block. + dw_channels (int): The intermediate channels of the block, + i.e., input channels of the large kernel convolution. + block_lk_size (int): size of the super large kernel. Defaults: 31. + small_kernel (int): size of the parallel small kernel. Defaults: 5. + drop_path (float): Stochastic depth rate. Defaults: 0. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + act_cfg (dict): Config dict for activation layer. + Default to ``dict(type='ReLU')``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default to None + """ + + def __init__(self, + in_channels, + dw_channels, + block_lk_size, + small_kernel, + drop_path, + small_kernel_merged=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(RepLKBlock, self).__init__(init_cfg) + self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1) + self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1) + self.large_kernel = ReparamLargeKernelConv( + in_channels=dw_channels, + out_channels=dw_channels, + kernel_size=block_lk_size, + stride=1, + groups=dw_channels, + small_kernel=small_kernel, + small_kernel_merged=small_kernel_merged) + self.lk_nonlinear = build_activation_layer(act_cfg) + self.prelkb_bn = build_norm_layer(norm_cfg, in_channels)[1] + self.drop_path = DropPath( + drop_prob=drop_path) if drop_path > 0. else nn.Identity() + # print('drop path:', self.drop_path) + + def forward(self, x): + out = self.prelkb_bn(x) + out = self.pw1(out) + out = self.large_kernel(out) + out = self.lk_nonlinear(out) + out = self.pw2(out) + return x + self.drop_path(out) + + +class RepLKNetStage(BaseModule): + """ + generate RepLKNet blocks for a stage + return: RepLKNet blocks + + Args: + channels (int): The input channels of the stage. + num_blocks (int): The number of blocks of the stage. + stage_lk_size (int): size of the super large kernel. Defaults: 31. + drop_path (float): Stochastic depth rate. Defaults: 0. + small_kernel (int): size of the parallel small kernel. Defaults: 5. + dw_ratio (float): The intermediate channels + expansion ratio of the block. Defaults: 1. + ffn_ratio (float): Mlp expansion ratio. Defaults to 4. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default to False. + small_kernel_merged (bool): Whether to switch the model structure to + deployment mode (merge the small kernel to the large kernel). + Default to False. + norm_intermediate_features (bool): Construct and config norm layer + or not. 
+ Using True will normalize the intermediate features for + downstream dense prediction tasks. + norm_cfg (dict): dictionary to construct and config norm layer. + Default to ``dict(type='BN', requires_grad=True)``. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default to None + """ + + def __init__( + self, + channels, + num_blocks, + stage_lk_size, + drop_path, + small_kernel, + dw_ratio=1, + ffn_ratio=4, + with_cp=False, # train with torch.utils.checkpoint to save memory + small_kernel_merged=False, + norm_intermediate_features=False, + norm_cfg=dict(type='BN'), + init_cfg=None): + super(RepLKNetStage, self).__init__(init_cfg) + self.with_cp = with_cp + blks = [] + for i in range(num_blocks): + block_drop_path = drop_path[i] if isinstance(drop_path, + list) else drop_path + # Assume all RepLK Blocks within a stage share the same lk_size. + # You may tune it on your own model. + replk_block = RepLKBlock( + in_channels=channels, + dw_channels=int(channels * dw_ratio), + block_lk_size=stage_lk_size, + small_kernel=small_kernel, + drop_path=block_drop_path, + small_kernel_merged=small_kernel_merged) + convffn_block = ConvFFN( + in_channels=channels, + internal_channels=int(channels * ffn_ratio), + out_channels=channels, + drop_path=block_drop_path) + blks.append(replk_block) + blks.append(convffn_block) + self.blocks = nn.ModuleList(blks) + if norm_intermediate_features: + self.norm = build_norm_layer(norm_cfg, channels)[1] + else: + self.norm = nn.Identity() + + def forward(self, x): + for blk in self.blocks: + if self.with_cp: + x = checkpoint.checkpoint(blk, x) # Save training memory + else: + x = blk(x) + return x + + +@MODELS.register_module() +class RepLKNet(BaseBackbone): + """RepLKNet backbone. + + A PyTorch impl of : + `Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs + `_ + + Args: + arch (str | dict): The parameter of RepLKNet. + If it's a dict, it should contain the following keys: + + - large_kernel_sizes (Sequence[int]): + Large kernel size in each stage. + - layers (Sequence[int]): Number of blocks in each stage. + - channels (Sequence[int]): Number of channels in each stage. + - small_kernel (int): size of the parallel small kernel. + - dw_ratio (float): The intermediate channels + expansion ratio of the block. + in_channels (int): Number of input image channels. Default to 3. + ffn_ratio (float): Mlp expansion ratio. Defaults to 4. + out_indices (Sequence[int]): Output from which stages. + Default to (3, ). + strides (Sequence[int]): Strides of the first block of each stage. + Default to (2, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. + Default to (1, 1, 1, 1). + frozen_stages (int): Stages to be frozen + (all param fixed). -1 means not freezing any parameters. + Default to -1. + conv_cfg (dict | None): The config dict for conv layers. + Default to None. + norm_cfg (dict): The config dict for norm layers. + Default to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Default to ``dict(type='ReLU')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default to False. + deploy (bool): Whether to switch the model structure to deployment + mode. Default to False. + norm_intermediate_features (bool): Construct and + config norm layer or not. + Using True will normalize the intermediate features + for downstream dense prediction tasks. 
+ norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + arch_settings = { + '31B': + dict( + large_kernel_sizes=[31, 29, 27, 13], + layers=[2, 2, 18, 2], + channels=[128, 256, 512, 1024], + small_kernel=5, + dw_ratio=1), + '31L': + dict( + large_kernel_sizes=[31, 29, 27, 13], + layers=[2, 2, 18, 2], + channels=[192, 384, 768, 1536], + small_kernel=5, + dw_ratio=1), + 'XL': + dict( + large_kernel_sizes=[27, 27, 27, 13], + layers=[2, 2, 18, 2], + channels=[256, 512, 1024, 2048], + small_kernel=None, + dw_ratio=1.5), + } + + def __init__(self, + arch, + in_channels=3, + ffn_ratio=4, + out_indices=(3, ), + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + drop_path_rate=0.3, + small_kernel_merged=False, + norm_intermediate_features=False, + norm_eval=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(RepLKNet, self).__init__(init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + assert len(arch['layers']) == len( + arch['channels']) == len(strides) == len(dilations) + assert max(out_indices) < len(arch['layers']) + + self.arch = arch + self.in_channels = in_channels + self.out_indices = out_indices + self.strides = strides + self.dilations = dilations + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.with_cp = with_cp + self.drop_path_rate = drop_path_rate + self.small_kernel_merged = small_kernel_merged + self.norm_eval = norm_eval + self.norm_intermediate_features = norm_intermediate_features + + self.out_indices = out_indices + + base_width = self.arch['channels'][0] + self.norm_intermediate_features = norm_intermediate_features + self.num_stages = len(self.arch['layers']) + self.stem = nn.ModuleList([ + conv_bn_relu( + in_channels=in_channels, + out_channels=base_width, + kernel_size=3, + stride=2, + padding=1, + groups=1), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=3, + stride=1, + padding=1, + groups=base_width), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=1, + stride=1, + padding=0, + groups=1), + conv_bn_relu( + in_channels=base_width, + out_channels=base_width, + kernel_size=3, + stride=2, + padding=1, + groups=base_width) + ]) + # stochastic depth. We set block-wise drop-path rate. + # The higher level blocks are more likely to be dropped. + # This implementation follows Swin. 
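+        # e.g. for the '31B' arch (2 + 2 + 18 + 2 = 24 blocks) with the default
+        # drop_path_rate=0.3, per-block rates rise linearly from 0.0 to 0.3.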
+ dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, + sum(self.arch['layers'])) + ] + self.stages = nn.ModuleList() + self.transitions = nn.ModuleList() + for stage_idx in range(self.num_stages): + layer = RepLKNetStage( + channels=self.arch['channels'][stage_idx], + num_blocks=self.arch['layers'][stage_idx], + stage_lk_size=self.arch['large_kernel_sizes'][stage_idx], + drop_path=dpr[sum(self.arch['layers'][:stage_idx] + ):sum(self.arch['layers'][:stage_idx + 1])], + small_kernel=self.arch['small_kernel'], + dw_ratio=self.arch['dw_ratio'], + ffn_ratio=ffn_ratio, + with_cp=with_cp, + small_kernel_merged=small_kernel_merged, + norm_intermediate_features=(stage_idx in out_indices)) + self.stages.append(layer) + if stage_idx < len(self.arch['layers']) - 1: + transition = nn.Sequential( + conv_bn_relu( + self.arch['channels'][stage_idx], + self.arch['channels'][stage_idx + 1], + 1, + 1, + 0, + groups=1), + conv_bn_relu( + self.arch['channels'][stage_idx + 1], + self.arch['channels'][stage_idx + 1], + 3, + stride=2, + padding=1, + groups=self.arch['channels'][stage_idx + 1])) + self.transitions.append(transition) + + def forward_features(self, x): + x = self.stem[0](x) + for stem_layer in self.stem[1:]: + if self.with_cp: + x = checkpoint.checkpoint(stem_layer, x) # save memory + else: + x = stem_layer(x) + + # Need the intermediate feature maps + outs = [] + for stage_idx in range(self.num_stages): + x = self.stages[stage_idx](x) + if stage_idx in self.out_indices: + outs.append(self.stages[stage_idx].norm(x)) + # For RepLKNet-XL normalize the features + # before feeding them into the heads + if stage_idx < self.num_stages - 1: + x = self.transitions[stage_idx](x) + return outs + + def forward(self, x): + x = self.forward_features(x) + return tuple(x) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RepLKNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def switch_to_deploy(self): + for m in self.modules(): + if hasattr(m, 'merge_kernel'): + m.merge_kernel() + self.small_kernel_merged = True diff --git a/mmpretrain/models/backbones/repmlp.py b/mmpretrain/models/backbones/repmlp.py new file mode 100644 index 0000000..f7c06c4 --- /dev/null +++ b/mmpretrain/models/backbones/repmlp.py @@ -0,0 +1,578 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from official impl at https://github.com/DingXiaoH/RepMLP. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmpretrain.models.utils import SELayer, to_2tuple +from mmpretrain.registry import MODELS + + +def fuse_bn(conv_or_fc, bn): + """fuse conv and bn.""" + std = (bn.running_var + bn.eps).sqrt() + tmp_weight = bn.weight / std + tmp_weight = tmp_weight.reshape(-1, 1, 1, 1) + + if len(tmp_weight) == conv_or_fc.weight.size(0): + return (conv_or_fc.weight * tmp_weight, + bn.bias - bn.running_mean * bn.weight / std) + else: + # in RepMLPBlock, dim0 of fc3 weights and fc3_bn weights + # are different. 
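+        # Repeat the per-shareset BN statistics so they align with dim 0 of the
+        # grouped fc3 weight before fusing.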
+ repeat_times = conv_or_fc.weight.size(0) // len(tmp_weight) + repeated = tmp_weight.repeat_interleave(repeat_times, 0) + fused_weight = conv_or_fc.weight * repeated + bias = bn.bias - bn.running_mean * bn.weight / std + fused_bias = (bias).repeat_interleave(repeat_times, 0) + return (fused_weight, fused_bias) + + +class PatchEmbed(_PatchEmbed): + """Image to Patch Embedding. + + Compared with default Patch Embedding(in ViT), Patch Embedding of RepMLP + have ReLu and do not convert output tensor into shape (N, L, C). + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The type of convolution + to generate patch embedding. Default: "Conv2d". + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: 16. + padding (int | tuple | string): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only works when `dynamic_size` + is False. Default: None. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, *args, **kwargs): + super(PatchEmbed, self).__init__(*args, **kwargs) + self.relu = nn.ReLU() + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + Returns: + tuple: Contains merged results and its spatial shape. + - x (Tensor): The output tensor. + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adaptive_padding: + x = self.adaptive_padding(x) + + x = self.projection(x) + if self.norm is not None: + x = self.norm(x) + x = self.relu(x) + out_size = (x.shape[2], x.shape[3]) + return x, out_size + + +class GlobalPerceptron(SELayer): + """GlobalPerceptron implemented by using ``mmpretrain.modes.SELayer``. + + Args: + input_channels (int): The number of input (and output) channels + in the GlobalPerceptron. + ratio (int): Squeeze ratio in GlobalPerceptron, the intermediate + channel will be ``make_divisible(channels // ratio, divisor)``. + """ + + def __init__(self, input_channels: int, ratio: int, **kwargs) -> None: + super(GlobalPerceptron, self).__init__( + channels=input_channels, + ratio=ratio, + return_weight=True, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + **kwargs) + + +class RepMLPBlock(BaseModule): + """Basic RepMLPNet, consists of PartitionPerceptron and GlobalPerceptron. + + Args: + channels (int): The number of input and the output channels of the + block. + path_h (int): The height of patches. + path_w (int): The weidth of patches. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). 
+ deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + channels, + path_h, + path_w, + reparam_conv_kernels=None, + globalperceptron_ratio=4, + num_sharesets=1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + deploy=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.deploy = deploy + self.channels = channels + self.num_sharesets = num_sharesets + self.path_h, self.path_w = path_h, path_w + # the input channel of fc3 + self._path_vec_channles = path_h * path_w * num_sharesets + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.gp = GlobalPerceptron( + input_channels=channels, ratio=globalperceptron_ratio) + + # using a conv layer to implement a fc layer + self.fc3 = build_conv_layer( + conv_cfg, + in_channels=self._path_vec_channles, + out_channels=self._path_vec_channles, + kernel_size=1, + stride=1, + padding=0, + bias=deploy, + groups=num_sharesets) + if deploy: + self.fc3_bn = nn.Identity() + else: + norm_layer = build_norm_layer(norm_cfg, num_sharesets)[1] + self.add_module('fc3_bn', norm_layer) + + self.reparam_conv_kernels = reparam_conv_kernels + if not deploy and reparam_conv_kernels is not None: + for k in reparam_conv_kernels: + conv_branch = ConvModule( + in_channels=num_sharesets, + out_channels=num_sharesets, + kernel_size=k, + stride=1, + padding=k // 2, + norm_cfg=dict(type='BN', requires_grad=True), + groups=num_sharesets, + act_cfg=None) + self.__setattr__('repconv{}'.format(k), conv_branch) + + def partition(self, x, h_parts, w_parts): + # convert (N, C, H, W) to (N, h_parts, w_parts, C, path_h, path_w) + x = x.reshape(-1, self.channels, h_parts, self.path_h, w_parts, + self.path_w) + x = x.permute(0, 2, 4, 1, 3, 5) + return x + + def partition_affine(self, x, h_parts, w_parts): + """perform Partition Perceptron.""" + fc_inputs = x.reshape(-1, self._path_vec_channles, 1, 1) + out = self.fc3(fc_inputs) + out = out.reshape(-1, self.num_sharesets, self.path_h, self.path_w) + out = self.fc3_bn(out) + out = out.reshape(-1, h_parts, w_parts, self.num_sharesets, + self.path_h, self.path_w) + return out + + def forward(self, inputs): + # Global Perceptron + global_vec = self.gp(inputs) + + origin_shape = inputs.size() + h_parts = origin_shape[2] // self.path_h + w_parts = origin_shape[3] // self.path_w + + partitions = self.partition(inputs, h_parts, w_parts) + + # Channel Perceptron + fc3_out = self.partition_affine(partitions, h_parts, w_parts) + + # perform Local Perceptron + if self.reparam_conv_kernels is not None and not self.deploy: + conv_inputs = partitions.reshape(-1, self.num_sharesets, + self.path_h, self.path_w) + conv_out = 0 + for k in self.reparam_conv_kernels: + conv_branch = self.__getattr__('repconv{}'.format(k)) + conv_out += conv_branch(conv_inputs) + conv_out = conv_out.reshape(-1, h_parts, w_parts, + self.num_sharesets, self.path_h, + self.path_w) + fc3_out += conv_out + + # N, h_parts, w_parts, num_sharesets, out_h, out_w + fc3_out = fc3_out.permute(0, 3, 1, 4, 2, 5) + out = fc3_out.reshape(*origin_shape) + out = out * global_vec + return out + + def get_equivalent_fc3(self): + """get the equivalent fc3 weight and bias.""" + fc_weight, fc_bias = fuse_bn(self.fc3, self.fc3_bn) + if self.reparam_conv_kernels is not None: + largest_k = max(self.reparam_conv_kernels) + largest_branch = self.__getattr__('repconv{}'.format(largest_k)) + total_kernel, 
total_bias = fuse_bn(largest_branch.conv, + largest_branch.bn) + for k in self.reparam_conv_kernels: + if k != largest_k: + k_branch = self.__getattr__('repconv{}'.format(k)) + kernel, bias = fuse_bn(k_branch.conv, k_branch.bn) + total_kernel += F.pad(kernel, [(largest_k - k) // 2] * 4) + total_bias += bias + rep_weight, rep_bias = self._convert_conv_to_fc( + total_kernel, total_bias) + final_fc3_weight = rep_weight.reshape_as(fc_weight) + fc_weight + final_fc3_bias = rep_bias + fc_bias + else: + final_fc3_weight = fc_weight + final_fc3_bias = fc_bias + return final_fc3_weight, final_fc3_bias + + def local_inject(self): + """inject the Local Perceptron into Partition Perceptron.""" + self.deploy = True + # Locality Injection + fc3_weight, fc3_bias = self.get_equivalent_fc3() + # Remove Local Perceptron + if self.reparam_conv_kernels is not None: + for k in self.reparam_conv_kernels: + self.__delattr__('repconv{}'.format(k)) + self.__delattr__('fc3') + self.__delattr__('fc3_bn') + self.fc3 = build_conv_layer( + self.conv_cfg, + self._path_vec_channles, + self._path_vec_channles, + 1, + 1, + 0, + bias=True, + groups=self.num_sharesets) + self.fc3_bn = nn.Identity() + self.fc3.weight.data = fc3_weight + self.fc3.bias.data = fc3_bias + + def _convert_conv_to_fc(self, conv_kernel, conv_bias): + """convert conv_k1 to fc, which is still a conv_k2, and the k2 > k1.""" + in_channels = torch.eye(self.path_h * self.path_w).repeat( + 1, self.num_sharesets).reshape(self.path_h * self.path_w, + self.num_sharesets, self.path_h, + self.path_w).to(conv_kernel.device) + fc_k = F.conv2d( + in_channels, + conv_kernel, + padding=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2), + groups=self.num_sharesets) + fc_k = fc_k.reshape(self.path_w * self.path_w, self.num_sharesets * + self.path_h * self.path_w).t() + fc_bias = conv_bias.repeat_interleave(self.path_h * self.path_w) + return fc_k, fc_bias + + +class RepMLPNetUnit(BaseModule): + """A basic unit in RepMLPNet : [REPMLPBlock + BN + ConvFFN + BN]. + + Args: + channels (int): The number of input and the output channels of the + unit. + path_h (int): The height of patches. + path_w (int): The weidth of patches. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + + def __init__(self, + channels, + path_h, + path_w, + reparam_conv_kernels, + globalperceptron_ratio, + norm_cfg=dict(type='BN', requires_grad=True), + ffn_expand=4, + num_sharesets=1, + deploy=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.repmlp_block = RepMLPBlock( + channels=channels, + path_h=path_h, + path_w=path_w, + reparam_conv_kernels=reparam_conv_kernels, + globalperceptron_ratio=globalperceptron_ratio, + num_sharesets=num_sharesets, + deploy=deploy) + self.ffn_block = ConvFFN(channels, channels * ffn_expand) + norm1 = build_norm_layer(norm_cfg, channels)[1] + self.add_module('norm1', norm1) + norm2 = build_norm_layer(norm_cfg, channels)[1] + self.add_module('norm2', norm2) + + def forward(self, x): + y = x + self.repmlp_block(self.norm1(x)) + out = y + self.ffn_block(self.norm2(y)) + return out + + +class ConvFFN(nn.Module): + """ConvFFN implemented by using point-wise convs.""" + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='GELU')): + super().__init__() + out_features = out_channels or in_channels + hidden_features = hidden_channels or in_channels + self.ffn_fc1 = ConvModule( + in_channels=in_channels, + out_channels=hidden_features, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + act_cfg=None) + self.ffn_fc2 = ConvModule( + in_channels=hidden_features, + out_channels=out_features, + kernel_size=1, + stride=1, + padding=0, + norm_cfg=norm_cfg, + act_cfg=None) + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + x = self.ffn_fc1(x) + x = self.act(x) + x = self.ffn_fc2(x) + return x + + +@MODELS.register_module() +class RepMLPNet(BaseModule): + """RepMLPNet backbone. + + A PyTorch impl of : `RepMLP: Re-parameterizing Convolutions into + Fully-connected Layers for Image Recognition + `_ + + Args: + arch (str | dict): RepMLP architecture. If use string, choose + from 'base' and 'b'. If use dict, it should have below keys: + + - channels (List[int]): Number of blocks in each stage. + - depths (List[int]): The number of blocks in each branch. + - sharesets_nums (List[int]): RepVGG Block that declares + the need to apply group convolution. + + img_size (int | tuple): The size of input image. Defaults: 224. + in_channels (int): Number of input image channels. Default: 3. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + reparam_conv_kernels (Squeue(int) | None): The conv kernels in the + GlobalPerceptron. Default: None. + globalperceptron_ratio (int): The reducation ratio in the + GlobalPerceptron. Default: 4. + num_sharesets (int): The number of sharesets in the + PartitionPerceptron. Default 1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + Default: dict(type='BN', requires_grad=True). + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to deployment + mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. 
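+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import RepMLPNet
+        >>> # Illustrative sketch; inputs must match ``img_size`` exactly.
+        >>> model = RepMLPNet(arch='b', img_size=224, out_indices=(3, ))
+        >>> feats = model(torch.rand(1, 3, 224, 224))
+        >>> feats[0].shape
+        torch.Size([1, 768, 7, 7])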
+ """ + arch_zoo = { + **dict.fromkeys(['b', 'base'], + {'channels': [96, 192, 384, 768], + 'depths': [2, 2, 12, 2], + 'sharesets_nums': [1, 4, 32, 128]}), + } # yapf: disable + + num_extra_tokens = 0 # there is no cls-token in RepMLP + + def __init__(self, + arch, + img_size=224, + in_channels=3, + patch_size=4, + out_indices=(3, ), + reparam_conv_kernels=(3, ), + globalperceptron_ratio=4, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + patch_cfg=dict(), + final_norm=True, + deploy=False, + init_cfg=None): + super(RepMLPNet, self).__init__(init_cfg=init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'channels', 'depths', 'sharesets_nums'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}.' + self.arch_settings = arch + + self.img_size = to_2tuple(img_size) + self.patch_size = to_2tuple(patch_size) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.num_stage = len(self.arch_settings['channels']) + for value in self.arch_settings.values(): + assert isinstance(value, list) and len(value) == self.num_stage, ( + 'Length of setting item in arch dict must be type of list and' + ' have the same length.') + + self.channels = self.arch_settings['channels'] + self.depths = self.arch_settings['depths'] + self.sharesets_nums = self.arch_settings['sharesets_nums'] + + _patch_cfg = dict( + in_channels=in_channels, + input_size=self.img_size, + embed_dims=self.channels[0], + conv_type='Conv2d', + kernel_size=self.patch_size, + stride=self.patch_size, + norm_cfg=self.norm_cfg, + bias=False) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + self.patch_hs = [ + self.patch_resolution[0] // 2**i for i in range(self.num_stage) + ] + self.patch_ws = [ + self.patch_resolution[1] // 2**i for i in range(self.num_stage) + ] + + self.stages = ModuleList() + self.downsample_layers = ModuleList() + for stage_idx in range(self.num_stage): + # make stage layers + _stage_cfg = dict( + channels=self.channels[stage_idx], + path_h=self.patch_hs[stage_idx], + path_w=self.patch_ws[stage_idx], + reparam_conv_kernels=reparam_conv_kernels, + globalperceptron_ratio=globalperceptron_ratio, + norm_cfg=self.norm_cfg, + ffn_expand=4, + num_sharesets=self.sharesets_nums[stage_idx], + deploy=deploy) + stage_blocks = [ + RepMLPNetUnit(**_stage_cfg) + for _ in range(self.depths[stage_idx]) + ] + self.stages.append(Sequential(*stage_blocks)) + + # make downsample layers + if stage_idx < self.num_stage - 1: + self.downsample_layers.append( + ConvModule( + in_channels=self.channels[stage_idx], + out_channels=self.channels[stage_idx + 1], + kernel_size=2, + stride=2, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + + self.out_indice = out_indices + + if final_norm: + norm_layer = build_norm_layer(norm_cfg, self.channels[-1])[1] + else: + norm_layer = nn.Identity() + self.add_module('final_norm', norm_layer) + + def forward(self, x): + assert x.shape[2:] == self.img_size, \ + "The Rep-MLP doesn't support dynamic input shape. 
" \ + f'Please input images with shape {self.img_size}' + + outs = [] + + x, _ = self.patch_embed(x) + for i, stage in enumerate(self.stages): + x = stage(x) + + # downsample after each stage except last stage + if i < len(self.stages) - 1: + downsample = self.downsample_layers[i] + x = downsample(x) + + if i in self.out_indice: + if self.final_norm and i == len(self.stages) - 1: + out = self.final_norm(x) + else: + out = x + outs.append(out) + + return tuple(outs) + + def switch_to_deploy(self): + for m in self.modules(): + if hasattr(m, 'local_inject'): + m.local_inject() diff --git a/mmpretrain/models/backbones/repvgg.py b/mmpretrain/models/backbones/repvgg.py new file mode 100644 index 0000000..67c9d14 --- /dev/null +++ b/mmpretrain/models/backbones/repvgg.py @@ -0,0 +1,622 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmengine.model import BaseModule, Sequential +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm +from torch import nn + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .base_backbone import BaseBackbone + + +class RepVGGBlock(BaseModule): + """RepVGG block for RepVGG backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1. + padding (int): Padding of the 3x3 convolution layer. + dilation (int): Dilation of the 3x3 convolution layer. + groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1. + padding_mode (str): Padding mode of the 3x3 convolution layer. + Default: 'zeros'. + se_cfg (None or dict): The configuration of the se module. + Default: None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + padding=1, + dilation=1, + groups=1, + padding_mode='zeros', + se_cfg=None, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + deploy=False, + init_cfg=None): + super(RepVGGBlock, self).__init__(init_cfg) + + assert se_cfg is None or isinstance(se_cfg, dict) + + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.se_cfg = se_cfg + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.deploy = deploy + + if deploy: + self.branch_reparam = build_conv_layer( + conv_cfg, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + padding_mode=padding_mode) + else: + # judge if input shape and output shape are the same. 
+ # If true, add a normalized identity shortcut. + if out_channels == in_channels and stride == 1 and \ + padding == dilation: + self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1] + else: + self.branch_norm = None + + self.branch_3x3 = self.create_conv_bn( + kernel_size=3, + dilation=dilation, + padding=padding, + ) + self.branch_1x1 = self.create_conv_bn(kernel_size=1) + + if se_cfg is not None: + self.se_layer = SELayer(channels=out_channels, **se_cfg) + else: + self.se_layer = None + + self.act = build_activation_layer(act_cfg) + + def create_conv_bn(self, kernel_size, dilation=1, padding=0): + conv_bn = Sequential() + conv_bn.add_module( + 'conv', + build_conv_layer( + self.conv_cfg, + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + stride=self.stride, + dilation=dilation, + padding=padding, + groups=self.groups, + bias=False)) + conv_bn.add_module( + 'norm', + build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1]) + + return conv_bn + + def forward(self, x): + + def _inner_forward(inputs): + if self.deploy: + return self.branch_reparam(inputs) + + if self.branch_norm is None: + branch_norm_out = 0 + else: + branch_norm_out = self.branch_norm(inputs) + + inner_out = self.branch_3x3(inputs) + self.branch_1x1( + inputs) + branch_norm_out + + if self.se_cfg is not None: + inner_out = self.se_layer(inner_out) + + return inner_out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.act(out) + + return out + + def switch_to_deploy(self): + """Switch the model structure from training mode to deployment mode.""" + if self.deploy: + return + assert self.norm_cfg['type'] == 'BN', \ + "Switch is not allowed when norm_cfg['type'] != 'BN'." + + reparam_weight, reparam_bias = self.reparameterize() + self.branch_reparam = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.out_channels, + kernel_size=3, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + bias=True) + self.branch_reparam.weight.data = reparam_weight + self.branch_reparam.bias.data = reparam_bias + + for param in self.parameters(): + param.detach_() + delattr(self, 'branch_3x3') + delattr(self, 'branch_1x1') + delattr(self, 'branch_norm') + + self.deploy = True + + def reparameterize(self): + """Fuse all the parameters of all branches. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all + branches. the first element is the weights and the second is + the bias. + """ + weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3) + weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1) + # pad a conv1x1 weight to a conv3x3 weight + weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0) + + weight_norm, bias_norm = 0, 0 + if self.branch_norm: + tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm) + weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn) + + return (weight_3x3 + weight_1x1 + weight_norm, + bias_3x3 + bias_1x1 + bias_norm) + + def _fuse_conv_bn(self, branch): + """Fuse the parameters in a branch with a conv and bn. + + Args: + branch (mmcv.runner.Sequential): A branch with conv and bn. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The parameters obtained after + fusing the parameters of conv and bn in one branch. + The first element is the weight and the second is the bias. 
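+
+        Note:
+            This is the standard conv-BN folding identity. With per-channel
+            BN statistics ``(mean, var)``, affine parameters
+            ``(gamma, beta)`` and ``std = sqrt(var + eps)``, the fused
+            parameters are ``w_fused = w * gamma / std`` and
+            ``b_fused = beta - mean * gamma / std``, which is exactly what
+            the code below computes channel-wise.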
+ """ + if branch is None: + return 0, 0 + conv_weight = branch.conv.weight + running_mean = branch.norm.running_mean + running_var = branch.norm.running_var + gamma = branch.norm.weight + beta = branch.norm.bias + eps = branch.norm.eps + + std = (running_var + eps).sqrt() + fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight + fused_bias = -running_mean * gamma / std + beta + + return fused_weight, fused_bias + + def _norm_to_conv3x3(self, branch_nrom): + """Convert a norm layer to a conv3x3-bn sequence. + + Args: + branch (nn.BatchNorm2d): A branch only with bn in the block. + + Returns: + tmp_conv3x3 (mmcv.runner.Sequential): a sequential with conv3x3 and + bn. + """ + input_dim = self.in_channels // self.groups + conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3), + dtype=branch_nrom.weight.dtype) + + for i in range(self.in_channels): + conv_weight[i, i % input_dim, 1, 1] = 1 + conv_weight = conv_weight.to(branch_nrom.weight.device) + + tmp_conv3x3 = self.create_conv_bn(kernel_size=3) + tmp_conv3x3.conv.weight.data = conv_weight + tmp_conv3x3.norm = branch_nrom + return tmp_conv3x3 + + +class MTSPPF(BaseModule): + """MTSPPF block for YOLOX-PAI RepVGG backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of pooling. Default: 5. + """ + + def __init__(self, + in_channels, + out_channels, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + kernel_size=5): + super().__init__() + hidden_features = in_channels // 2 # hidden channels + self.conv1 = ConvModule( + in_channels, + hidden_features, + 1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = ConvModule( + hidden_features * 4, + out_channels, + 1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.maxpool = nn.MaxPool2d( + kernel_size=kernel_size, stride=1, padding=kernel_size // 2) + + def forward(self, x): + x = self.conv1(x) + y1 = self.maxpool(x) + y2 = self.maxpool(y1) + return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1)) + + +@MODELS.register_module() +class RepVGG(BaseBackbone): + """RepVGG backbone. + + A PyTorch impl of : `RepVGG: Making VGG-style ConvNets Great Again + `_ + + Args: + arch (str | dict): RepVGG architecture. If use string, choose from + 'A0', 'A1`', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2', 'B2g2', + 'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'. If use dict, it should + have below keys: + + - **num_blocks** (Sequence[int]): Number of blocks in each stage. + - **width_factor** (Sequence[float]): Width deflator in each stage. + - **group_layer_map** (dict | None): RepVGG Block that declares + the need to apply group convolution. + - **se_cfg** (dict | None): SE Layer config. + - **stem_channels** (int, optional): The stem channels, the final + stem channels will be + ``min(stem_channels, base_channels*width_factor[0])``. + If not set here, 64 is used by default in the code. + + in_channels (int): Number of input image channels. Defaults to 3. + base_channels (int): Base channels of RepVGG backbone, work with + width_factor together. Defaults to 64. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(2, 2, 2, 2)``. 
+ dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. Defaults to -1. + conv_cfg (dict | None): The config dict for conv layers. + Defaults to None. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + deploy (bool): Whether to switch the model structure to deployment + mode. Defaults to False. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + add_ppf (bool): Whether to use the MTSPPF block. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] + g2_layer_map = {layer: 2 for layer in groupwise_layers} + g4_layer_map = {layer: 4 for layer in groupwise_layers} + + arch_settings = { + 'A0': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[0.75, 0.75, 0.75, 2.5], + group_layer_map=None, + se_cfg=None), + 'A1': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[1, 1, 1, 2.5], + group_layer_map=None, + se_cfg=None), + 'A2': + dict( + num_blocks=[2, 4, 14, 1], + width_factor=[1.5, 1.5, 1.5, 2.75], + group_layer_map=None, + se_cfg=None), + 'B0': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[1, 1, 1, 2.5], + group_layer_map=None, + se_cfg=None, + stem_channels=64), + 'B1': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=None, + se_cfg=None), + 'B1g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B1g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2, 2, 2, 4], + group_layer_map=g4_layer_map, + se_cfg=None), + 'B2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=None, + se_cfg=None), + 'B2g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B2g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=g4_layer_map, + se_cfg=None), + 'B3': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=None, + se_cfg=None), + 'B3g2': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=g2_layer_map, + se_cfg=None), + 'B3g4': + dict( + num_blocks=[4, 6, 16, 1], + width_factor=[3, 3, 3, 5], + group_layer_map=g4_layer_map, + se_cfg=None), + 'D2se': + dict( + num_blocks=[8, 14, 24, 1], + width_factor=[2.5, 2.5, 2.5, 5], + group_layer_map=None, + se_cfg=dict(ratio=16, divisor=1)), + 'yolox-pai-small': + dict( + num_blocks=[3, 5, 7, 3], + width_factor=[1, 1, 1, 1], + group_layer_map=None, + se_cfg=None, + stem_channels=32), + } + + def __init__(self, + arch, + in_channels=3, + base_channels=64, + out_indices=(3, ), + strides=(2, 2, 2, 2), + dilations=(1, 1, 1, 1), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + deploy=False, + norm_eval=False, + add_ppf=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 
'GroupNorm']) + ]): + super(RepVGG, self).__init__(init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'"arch": "{arch}" is not one of the arch_settings' + arch = self.arch_settings[arch] + elif not isinstance(arch, dict): + raise TypeError('Expect "arch" to be either a string ' + f'or a dict, got {type(arch)}') + + assert len(arch['num_blocks']) == len( + arch['width_factor']) == len(strides) == len(dilations) + assert max(out_indices) < len(arch['num_blocks']) + if arch['group_layer_map'] is not None: + assert max(arch['group_layer_map'].keys()) <= sum( + arch['num_blocks']) + + if arch['se_cfg'] is not None: + assert isinstance(arch['se_cfg'], dict) + + self.base_channels = base_channels + self.arch = arch + self.in_channels = in_channels + self.out_indices = out_indices + self.strides = strides + self.dilations = dilations + self.deploy = deploy + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + + # defaults to 64 to prevert BC-breaking if stem_channels + # not in arch dict; + # the stem channels should not be larger than that of stage1. + channels = min( + arch.get('stem_channels', 64), + int(self.base_channels * self.arch['width_factor'][0])) + self.stem = RepVGGBlock( + self.in_channels, + channels, + stride=2, + se_cfg=arch['se_cfg'], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + deploy=deploy) + + next_create_block_idx = 1 + self.stages = [] + for i in range(len(arch['num_blocks'])): + num_blocks = self.arch['num_blocks'][i] + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = int(self.base_channels * 2**i * + self.arch['width_factor'][i]) + + stage, next_create_block_idx = self._make_stage( + channels, out_channels, num_blocks, stride, dilation, + next_create_block_idx, init_cfg) + stage_name = f'stage_{i + 1}' + self.add_module(stage_name, stage) + self.stages.append(stage_name) + + channels = out_channels + + if add_ppf: + self.ppf = MTSPPF( + out_channels, + out_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + kernel_size=5) + else: + self.ppf = nn.Identity() + + def _make_stage(self, in_channels, out_channels, num_blocks, stride, + dilation, next_create_block_idx, init_cfg): + strides = [stride] + [1] * (num_blocks - 1) + dilations = [dilation] * num_blocks + + blocks = [] + for i in range(num_blocks): + groups = self.arch['group_layer_map'].get( + next_create_block_idx, + 1) if self.arch['group_layer_map'] is not None else 1 + blocks.append( + RepVGGBlock( + in_channels, + out_channels, + stride=strides[i], + padding=dilations[i], + dilation=dilations[i], + groups=groups, + se_cfg=self.arch['se_cfg'], + with_cp=self.with_cp, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + deploy=self.deploy, + init_cfg=init_cfg)) + in_channels = out_channels + next_create_block_idx += 1 + + return Sequential(*blocks), next_create_block_idx + + def forward(self, x): + x = self.stem(x) + outs = [] + for i, stage_name in enumerate(self.stages): + stage = getattr(self, stage_name) + x = stage(x) + if i + 1 == len(self.stages): + x = self.ppf(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + stage = getattr(self, f'stage_{i+1}') + stage.eval() + for 
param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RepVGG, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def switch_to_deploy(self): + for m in self.modules(): + if isinstance(m, RepVGGBlock): + m.switch_to_deploy() + self.deploy = True diff --git a/mmpretrain/models/backbones/res2net.py b/mmpretrain/models/backbones/res2net.py new file mode 100644 index 0000000..6e9bb6d --- /dev/null +++ b/mmpretrain/models/backbones/res2net.py @@ -0,0 +1,317 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import ModuleList, Sequential + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottle2neck(_Bottleneck): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + scales=4, + base_width=26, + base_channels=64, + stage_type='normal', + **kwargs): + """Bottle2neck block for Res2Net.""" + super(Bottle2neck, self).__init__(in_channels, out_channels, **kwargs) + assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.' + + mid_channels = out_channels // self.expansion + width = int(math.floor(mid_channels * (base_width / base_channels))) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width * scales, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + width * scales, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + + if stage_type == 'stage': + self.pool = nn.AvgPool2d( + kernel_size=3, stride=self.conv2_stride, padding=1) + + self.convs = ModuleList() + self.bns = ModuleList() + for i in range(scales - 1): + self.convs.append( + build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + bias=False)) + self.bns.append( + build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1]) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width * scales, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.stage_type = stage_type + self.scales = scales + self.width = width + delattr(self, 'conv2') + delattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + spx = torch.split(out, self.width, 1) + sp = self.convs[0](spx[0].contiguous()) + sp = self.relu(self.bns[0](sp)) + out = sp + for i in range(1, self.scales - 1): + if self.stage_type == 'stage': + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp.contiguous()) + sp = self.relu(self.bns[i](sp)) + out = torch.cat((out, sp), 1) + + if self.stage_type == 'normal' and self.scales != 1: + out = torch.cat((out, spx[self.scales - 1]), 1) + elif self.stage_type == 'stage' and self.scales != 1: + out = torch.cat((out, self.pool(spx[self.scales - 1])), 1) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = 
cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Res2Layer(Sequential): + """Res2Layer to build Res2Net style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottle2neck. Defaults to True. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + scales (int): Scales used in Res2Net. Default: 4 + base_width (int): Basic width of each scale. Default: 26 + drop_path_rate (float or np.ndarray): stochastic depth rate. + Default: 0. + """ + + def __init__(self, + block, + in_channels, + out_channels, + num_blocks, + stride=1, + avg_down=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + scales=4, + base_width=26, + drop_path_rate=0.0, + **kwargs): + self.block = block + + if isinstance(drop_path_rate, float): + drop_path_rate = [drop_path_rate] * num_blocks + + assert len(drop_path_rate + ) == num_blocks, 'Please check the length of drop_path_rate' + + downsample = None + if stride != 1 or in_channels != out_channels: + if avg_down: + downsample = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False), + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1], + ) + else: + downsample = nn.Sequential( + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1], + ) + + layers = [] + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + scales=scales, + base_width=base_width, + stage_type='stage', + drop_path_rate=drop_path_rate[0], + **kwargs)) + in_channels = out_channels + for i in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + scales=scales, + base_width=base_width, + drop_path_rate=drop_path_rate[i], + **kwargs)) + super(Res2Layer, self).__init__(*layers) + + +@MODELS.register_module() +class Res2Net(ResNet): + """Res2Net backbone. + + A PyTorch implement of : `Res2Net: A New Multi-scale Backbone + Architecture `_ + + Args: + depth (int): Depth of Res2Net, choose from {50, 101, 152}. + scales (int): Scales used in Res2Net. Defaults to 4. + base_width (int): Basic width of each scale. Defaults to 26. + in_channels (int): Number of input image channels. Defaults to 3. + num_stages (int): Number of Res2Net stages. Defaults to 4. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + style (str): "pytorch" or "caffe". If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Defaults to "pytorch". + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to True. 
+ avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottle2neck. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to ``dict(type='BN', requires_grad=True)``. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + + Example: + >>> from mmpretrain.models import Res2Net + >>> import torch + >>> model = Res2Net(depth=50, + ... scales=4, + ... base_width=26, + ... out_indices=(0, 1, 2, 3)) + >>> model.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = model.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottle2neck, (3, 4, 6, 3)), + 101: (Bottle2neck, (3, 4, 23, 3)), + 152: (Bottle2neck, (3, 8, 36, 3)) + } + + def __init__(self, + scales=4, + base_width=26, + style='pytorch', + deep_stem=True, + avg_down=True, + init_cfg=None, + **kwargs): + self.scales = scales + self.base_width = base_width + super(Res2Net, self).__init__( + style=style, + deep_stem=deep_stem, + avg_down=avg_down, + init_cfg=init_cfg, + **kwargs) + + def make_res_layer(self, **kwargs): + return Res2Layer( + scales=self.scales, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/resnest.py b/mmpretrain/models/backbones/resnest.py new file mode 100644 index 0000000..4bb438f --- /dev/null +++ b/mmpretrain/models/backbones/resnest.py @@ -0,0 +1,339 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. 
+ conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN')): + super(SplitAttentionConv2d, self).__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + return getattr(self, self.norm0_name) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
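+
+    Note:
+        Compared with the plain ResNet bottleneck, ``conv2`` here is a
+        :class:`SplitAttentionConv2d`: the 3x3 conv expands the features to
+        ``radix`` splits, a pooled global descriptor is squeezed through two
+        1x1 convs, and an r-softmax over the radix dimension produces the
+        attention used to fuse the splits back to ``mid_channels``.
+
+    Example:
+        >>> # Usage sketch; the block keeps the input shape when stride=1
+        >>> # and in_channels == out_channels, so no downsample is needed.
+        >>> import torch
+        >>> block = Bottleneck(in_channels=256, out_channels=256, radix=2)
+        >>> block.eval()
+        >>> x = torch.rand(1, 256, 56, 56)
+        >>> block(x).shape
+        torch.Size([1, 256, 56, 56])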
+ """ + + def __init__(self, + in_channels, + out_channels, + groups=1, + width_per_group=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs) + + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = SplitAttentionConv2d( + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152, 200}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of SplitAttentionConv2d. + Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. 
If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)), + 269: (Bottleneck, (3, 30, 48, 8)) + } + + def __init__(self, + depth, + groups=1, + width_per_group=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.width_per_group = width_per_group + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super(ResNeSt, self).__init__(depth=depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/mmpretrain/models/backbones/resnet.py b/mmpretrain/models/backbones/resnet.py new file mode 100644 index 0000000..4a254f7 --- /dev/null +++ b/mmpretrain/models/backbones/resnet.py @@ -0,0 +1,768 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + +eps = 1.0e-5 + + +class BasicBlock(BaseModule): + """BasicBlock for ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the output channels of conv1. This is a + reserved argument in BasicBlock and should always be 1. Default: 1. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None. + style (str): `pytorch` or `caffe`. It is unused and reserved for + unified API with Bottleneck. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. 
Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + """ + + def __init__(self, + in_channels, + out_channels, + expansion=1, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + drop_path_rate=0.0, + act_cfg=dict(type='ReLU', inplace=True), + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert self.expansion == 1 + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, out_channels, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + 3, + padding=1, + bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = build_activation_layer(act_cfg) + self.downsample = downsample + self.drop_path = DropPath(drop_prob=drop_path_rate + ) if drop_path_rate > eps else nn.Identity() + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.drop_path(out) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int): The ratio of ``out_channels/mid_channels`` where + ``mid_channels`` is the input/output channels of conv2. Default: 4. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None. + style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the + stride-two layer is the 3x3 conv layer, otherwise the stride-two + layer is the first 1x1 conv layer. Default: "pytorch". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. 
+ Default: dict(type='BN') + """ + + def __init__(self, + in_channels, + out_channels, + expansion=4, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU', inplace=True), + drop_path_rate=0.0, + init_cfg=None): + super(Bottleneck, self).__init__(init_cfg=init_cfg) + assert style in ['pytorch', 'caffe'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.expansion = expansion + assert out_channels % expansion == 0 + self.mid_channels = out_channels // expansion + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, out_channels, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + self.mid_channels, + out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = build_activation_layer(act_cfg) + self.downsample = downsample + self.drop_path = DropPath(drop_prob=drop_path_rate + ) if drop_path_rate > eps else nn.Identity() + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + @property + def norm3(self): + return getattr(self, self.norm3_name) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.drop_path(out) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def get_expansion(block, expansion=None): + """Get the expansion of a residual block. + + The block expansion will be obtained by the following order: + + 1. If ``expansion`` is given, just return it. + 2. If ``block`` has the attribute ``expansion``, then return + ``block.expansion``. + 3. Return the default value according the the block type: + 1 for ``BasicBlock`` and 4 for ``Bottleneck``. + + Args: + block (class): The block class. + expansion (int | None): The given expansion ratio. + + Returns: + int: The expansion of the block. 
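+
+    Example:
+        >>> # Illustration of the lookup order described above; the values
+        >>> # follow directly from the block definitions in this file.
+        >>> get_expansion(BasicBlock)
+        1
+        >>> get_expansion(Bottleneck)
+        4
+        >>> get_expansion(Bottleneck, expansion=2)
+        2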
+ """ + if isinstance(expansion, int): + assert expansion > 0 + elif expansion is None: + if hasattr(block, 'expansion'): + expansion = block.expansion + elif issubclass(block, BasicBlock): + expansion = 1 + elif issubclass(block, Bottleneck): + expansion = 4 + else: + raise TypeError(f'expansion is not specified for {block.__name__}') + else: + raise TypeError('expansion must be an integer or None') + + return expansion + + +class ResLayer(nn.Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): Residual block used to build ResLayer. + num_blocks (int): Number of blocks. + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + expansion (int, optional): The expansion for BasicBlock/Bottleneck. + If not specified, it will firstly be obtained via + ``block.expansion``. If the block has no attribute "expansion", + the following default values will be used: 1 for BasicBlock and + 4 for Bottleneck. Default: None. + stride (int): stride of the first block. Default: 1. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + drop_path_rate (float or list): stochastic depth rate. + Default: 0. + """ + + def __init__(self, + block, + num_blocks, + in_channels, + out_channels, + expansion=None, + stride=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + drop_path_rate=0.0, + **kwargs): + self.block = block + self.expansion = get_expansion(block, expansion) + + if isinstance(drop_path_rate, float): + drop_path_rate = [drop_path_rate] * num_blocks + + assert len(drop_path_rate + ) == num_blocks, 'Please check the length of drop_path_rate' + + downsample = None + if stride != 1 or in_channels != out_channels: + downsample = [] + conv_stride = stride + if avg_down and stride != 1: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, out_channels)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=drop_path_rate[0], + **kwargs)) + in_channels = out_channels + for i in range(1, num_blocks): + layers.append( + block( + in_channels=in_channels, + out_channels=out_channels, + expansion=self.expansion, + stride=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=drop_path_rate[i], + **kwargs)) + super(ResLayer, self).__init__(*layers) + + +@MODELS.register_module() +class ResNet(BaseBackbone): + """ResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + base_channels (int): Middle channels of the first stage. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. 
+ dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + expansion=None, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(3, ), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + zero_init_residual=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ], + drop_path_rate=0.0): + super(ResNet, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.zero_init_residual = zero_init_residual + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.expansion = get_expansion(self.block, expansion) + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + _in_channels = stem_channels + _out_channels = base_channels * self.expansion + + # stochastic depth decay rule + total_depth = sum(stage_blocks) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + + 
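+        # Each stage consumes the first ``num_blocks`` rates below via
+        # ``dpr[:num_blocks]`` and the list is then trimmed, so per-block
+        # drop rates increase linearly across the whole network instead of
+        # restarting at 0 for every stage.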
for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + res_layer = self.make_res_layer( + block=self.block, + num_blocks=num_blocks, + in_channels=_in_channels, + out_channels=_out_channels, + expansion=self.expansion, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + drop_path_rate=dpr[:num_blocks]) + _in_channels = _out_channels + _out_channels *= 2 + dpr = dpr[num_blocks:] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = res_layer[-1].out_channels + + def make_res_layer(self, **kwargs): + return ResLayer(**kwargs) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + if self.deep_stem: + self.stem = nn.Sequential( + ConvModule( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True), + ConvModule( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ResNet, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress zero_init_residual if use pretrained model. + return + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + + def forward(self, x): + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer id to set the different learning rates for ResNet. 
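+
+        This is meant for layer-wise learning rate decay. For a ResNet-50
+        (stage depths ``[3, 4, 6, 3]``), ``blk2=2`` and ``blk3=3``, so
+        parameters in ``backbone.layer1`` map to layer id 1, ``layer2`` to
+        ids 2-3, ``layer3`` to ids 4-5 and ``layer4`` to id 6, while stem
+        parameters get id 0 and non-backbone (head-like) parameters get the
+        extra id 7.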
+ + ResNet stages: + 50 : [3, 4, 6, 3] + 101 : [3, 4, 23, 3] + 152 : [3, 8, 36, 3] + 200 : [3, 24, 36, 3] + eca269d: [3, 30, 48, 8] + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + """ + depths = self.stage_blocks + if depths[1] == 4 and depths[2] == 6: + blk2, blk3 = 2, 3 + elif depths[1] == 4 and depths[2] == 23: + blk2, blk3 = 2, 3 + elif depths[1] == 8 and depths[2] == 36: + blk2, blk3 = 4, 4 + elif depths[1] == 24 and depths[2] == 36: + blk2, blk3 = 4, 4 + elif depths[1] == 30 and depths[2] == 48: + blk2, blk3 = 5, 6 + else: + raise NotImplementedError + + N2, N3 = math.ceil(depths[1] / blk2 - + 1e-5), math.ceil(depths[2] / blk3 - 1e-5) + N = 2 + N2 + N3 # r50: 2 + 2 + 2 = 6 + max_layer_id = N + 1 # r50: 2 + 2 + 2 + 1(like head) = 7 + + if not param_name.startswith(prefix): + # For subsequent module like head + return max_layer_id, max_layer_id + 1 + + if param_name.startswith('backbone.layer'): + stage_id = int(param_name.split('.')[1][5:]) + block_id = int(param_name.split('.')[2]) + + if stage_id == 1: + layer_id = 1 + elif stage_id == 2: + layer_id = 2 + block_id // blk2 # r50: 2, 3 + elif stage_id == 3: + layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5 + else: # stage_id == 4 + layer_id = N # r50: 6 + return layer_id, max_layer_id + 1 + + else: + return 0, max_layer_id + 1 + + +@MODELS.register_module() +class ResNetV1c(ResNet): + """ResNetV1c backbone. + + This variant is described in `Bag of Tricks. + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv + in the input stem with three 3x3 convs. + """ + + def __init__(self, **kwargs): + super(ResNetV1c, self).__init__( + deep_stem=True, avg_down=False, **kwargs) + + +@MODELS.register_module() +class ResNetV1d(ResNet): + """ResNetV1d backbone. + + This variant is described in `Bag of Tricks. + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/mmpretrain/models/backbones/resnet_cifar.py b/mmpretrain/models/backbones/resnet_cifar.py new file mode 100644 index 0000000..9f17f92 --- /dev/null +++ b/mmpretrain/models/backbones/resnet_cifar.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import ResNet + + +@MODELS.register_module() +class ResNet_CIFAR(ResNet): + """ResNet backbone for CIFAR. + + Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in + conv1, and does not apply MaxPoolinng after stem. It has been proven to + be more efficient than standard ResNet in other public codebase, e.g., + `https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`. + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + base_channels (int): Middle channels of the first stage. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. 
+ Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): This network has specific designed stem, thus it is + asserted to be False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + def __init__(self, depth, deep_stem=False, **kwargs): + super(ResNet_CIFAR, self).__init__( + depth, deep_stem=deep_stem, **kwargs) + assert not self.deep_stem, 'ResNet_CIFAR do not support deep_stem' + + def _make_stem_layer(self, in_channels, base_channels): + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + base_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, base_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) diff --git a/mmpretrain/models/backbones/resnext.py b/mmpretrain/models/backbones/resnext.py new file mode 100644 index 0000000..8858b7d --- /dev/null +++ b/mmpretrain/models/backbones/resnext.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResLayer, ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. 
Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + **kwargs): + super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # For ResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for ResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. + if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. 
Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super(ResNeXt, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/revvit.py b/mmpretrain/models/backbones/revvit.py new file mode 100644 index 0000000..f2e6c28 --- /dev/null +++ b/mmpretrain/models/backbones/revvit.py @@ -0,0 +1,671 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys + +import numpy as np +import torch +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from torch import nn +from torch.autograd import Function as Function + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) + + +class RevBackProp(Function): + """Custom Backpropagation function to allow (A) flushing memory in forward + and (B) activation recomputation reversibly in backward for gradient + calculation. + + Inspired by + https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py + """ + + @staticmethod + def forward( + ctx, + x, + layers, + buffer_layers, # List of layer ids for int activation to buffer + ): + """Reversible Forward pass. + + Any intermediate activations from `buffer_layers` are cached in ctx for + forward pass. This is not necessary for standard usecases. Each + reversible layer implements its own forward pass logic. + """ + buffer_layers.sort() + x1, x2 = torch.chunk(x, 2, dim=-1) + intermediate = [] + + for layer in layers: + x1, x2 = layer(x1, x2) + if layer.layer_id in buffer_layers: + intermediate.extend([x1.detach(), x2.detach()]) + + if len(buffer_layers) == 0: + all_tensors = [x1.detach(), x2.detach()] + else: + intermediate = [torch.LongTensor(buffer_layers), *intermediate] + all_tensors = [x1.detach(), x2.detach(), *intermediate] + + ctx.save_for_backward(*all_tensors) + ctx.layers = layers + + return torch.cat([x1, x2], dim=-1) + + @staticmethod + def backward(ctx, dx): + """Reversible Backward pass. + + Any intermediate activations from `buffer_layers` are recovered from + ctx. Each layer implements its own loic for backward pass (both + activation recomputation and grad calculation). 
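+
+        The layers are traversed in reverse order, and each layer first
+        re-computes its inputs from its outputs (``x2 = y2 - g(y1)``,
+        ``x1 = y1 - f(x2)``) before the gradients are propagated to the
+        previous block.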
+ """ + d_x1, d_x2 = torch.chunk(dx, 2, dim=-1) + # retrieve params from ctx for backward + x1, x2, *int_tensors = ctx.saved_tensors + # no buffering + if len(int_tensors) != 0: + buffer_layers = int_tensors[0].tolist() + else: + buffer_layers = [] + + layers = ctx.layers + + for _, layer in enumerate(layers[::-1]): + if layer.layer_id in buffer_layers: + x1, x2, d_x1, d_x2 = layer.backward_pass( + y1=int_tensors[buffer_layers.index(layer.layer_id) * 2 + + 1], + y2=int_tensors[buffer_layers.index(layer.layer_id) * 2 + + 2], + d_y1=d_x1, + d_y2=d_x2, + ) + else: + x1, x2, d_x1, d_x2 = layer.backward_pass( + y1=x1, + y2=x2, + d_y1=d_x1, + d_y2=d_x2, + ) + + dx = torch.cat([d_x1, d_x2], dim=-1) + + del int_tensors + del d_x1, d_x2, x1, x2 + + return dx, None, None + + +class RevTransformerEncoderLayer(BaseModule): + """Reversible Transformer Encoder Layer. + + This module is a building block of Reversible Transformer Encoder, + which support backpropagation without storing activations. + The residual connection is not applied to the FFN layer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + Default: 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0 + drop_path_rate (float): stochastic depth rate. + Default 0.0 + num_fcs (int): The number of linear in FFN + Default: 2 + qkv_bias (bool): enable bias for qkv if True. + Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU') + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + layer_id (int): The layer id of current layer. Used in RevBackProp. + Default: 0 + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + layer_id: int = 0, + init_cfg=None): + super(RevTransformerEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate) + self.embed_dims = embed_dims + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + qkv_bias=qkv_bias) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + act_cfg=act_cfg, + add_identity=False) + + self.layer_id = layer_id + self.seeds = {} + + def init_weights(self): + super(RevTransformerEncoderLayer, self).init_weights() + for m in self.ffn.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def seed_cuda(self, key): + """Fix seeds to allow for stochastic elements such as dropout to be + reproduced exactly in activation recomputation in the backward pass.""" + # randomize seeds + # use cuda generator if available + if (hasattr(torch.cuda, 'default_generators') + and len(torch.cuda.default_generators) > 0): + # GPU + device_idx = torch.cuda.current_device() + seed = torch.cuda.default_generators[device_idx].seed() + else: + # CPU + seed = int(torch.seed() % sys.maxsize) + + self.seeds[key] = seed + torch.manual_seed(self.seeds[key]) + + def forward(self, x1, x2): + """ + Implementation of Reversible TransformerEncoderLayer + + ` + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + ` + """ + self.seed_cuda('attn') + # attention output + f_x2 = self.attn(self.ln1(x2)) + # apply droppath on attention output + self.seed_cuda('droppath') + f_x2_dropped = build_dropout(self.drop_path_cfg)(f_x2) + y1 = x1 + f_x2_dropped + + # free memory + if self.training: + del x1 + + # ffn output + self.seed_cuda('ffn') + g_y1 = self.ffn(self.ln2(y1)) + # apply droppath on ffn output + torch.manual_seed(self.seeds['droppath']) + g_y1_dropped = build_dropout(self.drop_path_cfg)(g_y1) + # final output + y2 = x2 + g_y1_dropped + + # free memory + if self.training: + del x2 + + return y1, y2 + + def backward_pass(self, y1, y2, d_y1, d_y2): + """Activation re-compute with the following equation. 
+ + x2 = y2 - g(y1), g = FFN + x1 = y1 - f(x2), f = MSHA + """ + + # temporarily record intermediate activation for G + # and use them for gradient calculation of G + with torch.enable_grad(): + y1.requires_grad = True + + torch.manual_seed(self.seeds['ffn']) + g_y1 = self.ffn(self.ln2(y1)) + + torch.manual_seed(self.seeds['droppath']) + g_y1 = build_dropout(self.drop_path_cfg)(g_y1) + + g_y1.backward(d_y2, retain_graph=True) + + # activate recomputation is by design and not part of + # the computation graph in forward pass + with torch.no_grad(): + x2 = y2 - g_y1 + del g_y1 + + d_y1 = d_y1 + y1.grad + y1.grad = None + + # record F activation and calculate gradients on F + with torch.enable_grad(): + x2.requires_grad = True + + torch.manual_seed(self.seeds['attn']) + f_x2 = self.attn(self.ln1(x2)) + + torch.manual_seed(self.seeds['droppath']) + f_x2 = build_dropout(self.drop_path_cfg)(f_x2) + + f_x2.backward(d_y1, retain_graph=True) + + # propagate reverse computed activations at the + # start of the previous block + with torch.no_grad(): + x1 = y1 - f_x2 + del f_x2, y1 + + d_y2 = d_y2 + x2.grad + + x2.grad = None + x2 = x2.detach() + + return x1, x2, d_y1, d_y2 + + +class TwoStreamFusion(nn.Module): + """A general constructor for neural modules fusing two equal sized tensors + in forward. + + Args: + mode (str): The mode of fusion. Options are 'add', 'max', 'min', + 'avg', 'concat'. + """ + + def __init__(self, mode: str): + super().__init__() + self.mode = mode + + if mode == 'add': + self.fuse_fn = lambda x: torch.stack(x).sum(dim=0) + elif mode == 'max': + self.fuse_fn = lambda x: torch.stack(x).max(dim=0).values + elif mode == 'min': + self.fuse_fn = lambda x: torch.stack(x).min(dim=0).values + elif mode == 'avg': + self.fuse_fn = lambda x: torch.stack(x).mean(dim=0) + elif mode == 'concat': + self.fuse_fn = lambda x: torch.cat(x, dim=-1) + else: + raise NotImplementedError + + def forward(self, x): + # split the tensor into two halves in the channel dimension + x = torch.chunk(x, 2, dim=2) + return self.fuse_fn(x) + + +@MODELS.register_module() +class RevVisionTransformer(BaseBackbone): + """Reversible Vision Transformer. + + A PyTorch implementation of : `Reversible Vision Transformers + `_ # noqa: E501 + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. 
Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"avg_featmap"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + fusion_mode (str): The fusion mode of transformer layers. + Defaults to 'concat'. + no_custom_backward (bool): Whether to use custom backward. + Defaults to False. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 768, + 'num_layers': 8, + 'num_heads': 8, + 'feedforward_channels': 768 * 3, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['h', 'huge'], + { + # The same as the implementation in MAE + # + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120 + }), + **dict.fromkeys( + ['deit-t', 'deit-tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': 192 * 4 + }), + **dict.fromkeys( + ['deit-s', 'deit-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 384 * 4 + }), + **dict.fromkeys( + ['deit-b', 'deit-base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 768 * 4 + }), + } + num_extra_tokens = 0 # The official RevViT doesn't have class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='avg_featmap', + with_cls_token=False, + frozen_stages=-1, + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + fusion_mode='concat', + no_custom_backward=False, + init_cfg=None): + super(RevVisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + self.no_custom_backward = no_custom_backward + + # Set patch embedding + _patch_cfg = dict( + 
in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + layer_id=i, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(RevTransformerEncoderLayer(**_layer_cfg)) + + # fusion operation for the final output + self.fusion_layer = TwoStreamFusion(mode=fusion_mode) + + self.frozen_stages = frozen_stages + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims * 2) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + def init_weights(self): + super(RevVisionTransformer, self).init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + trunc_normal_(self.pos_embed, std=0.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + 
self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers) and self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = torch.cat([x, x], dim=-1) + + # forward with different conditions + if not self.training or self.no_custom_backward: + # in eval/inference model + executing_fn = RevVisionTransformer._forward_vanilla_bp + else: + # use custom backward when self.training=True. + executing_fn = RevBackProp.apply + + x = executing_fn(x, self.layers, []) + + if self.final_norm: + x = self.ln1(x) + x = self.fusion_layer(x) + + return (self._format_output(x, patch_resolution), ) + + @staticmethod + def _forward_vanilla_bp(hidden_state, layers, buffer=[]): + """Using reversible layers without reversible backpropagation. + + Debugging purpose only. Activated with self.no_custom_backward + """ + # split into ffn state(ffn_out) and attention output(attn_out) + ffn_out, attn_out = torch.chunk(hidden_state, 2, dim=-1) + del hidden_state + + for _, layer in enumerate(layers): + attn_out, ffn_out = layer(attn_out, ffn_out) + + return torch.cat([attn_out, ffn_out], dim=-1) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) diff --git a/mmpretrain/models/backbones/riformer.py b/mmpretrain/models/backbones/riformer.py new file mode 100644 index 0000000..ad7cb4d --- /dev/null +++ b/mmpretrain/models/backbones/riformer.py @@ -0,0 +1,390 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath, build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone +from .poolformer import Mlp, PatchEmbed + + +class Affine(nn.Module): + """Affine Transformation module. + + Args: + in_features (int): Input dimension. + """ + + def __init__(self, in_features): + super().__init__() + self.affine = nn.Conv2d( + in_features, + in_features, + kernel_size=1, + stride=1, + padding=0, + groups=in_features, + bias=True) + + def forward(self, x): + return self.affine(x) - x + + +class RIFormerBlock(BaseModule): + """RIFormer Block. + + Args: + dim (int): Embedding dim. + mlp_ratio (float): Mlp expansion ratio. Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='GN', num_groups=1)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + drop (float): Dropout rate. Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. 
+ layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-5. + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + """ + + def __init__(self, + dim, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop=0., + drop_path=0., + layer_scale_init_value=1e-5, + deploy=False): + + super().__init__() + + if deploy: + self.norm_reparam = build_norm_layer(norm_cfg, dim)[1] + else: + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.token_mixer = Affine(in_features=dim) + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + # The following two techniques are useful to train deep RIFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.norm_cfg = norm_cfg + self.dim = dim + self.deploy = deploy + + def forward(self, x): + if hasattr(self, 'norm_reparam'): + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.norm_reparam(x)) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + else: + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + return x + + def fuse_affine(self, norm, token_mixer): + gamma_affn = token_mixer.affine.weight.reshape(-1) + gamma_affn = gamma_affn - torch.ones_like(gamma_affn) + beta_affn = token_mixer.affine.bias + gamma_ln = norm.weight + beta_ln = norm.bias + return (gamma_ln * gamma_affn), (beta_ln * gamma_affn + beta_affn) + + def get_equivalent_scale_bias(self): + eq_s, eq_b = self.fuse_affine(self.norm1, self.token_mixer) + return eq_s, eq_b + + def switch_to_deploy(self): + if self.deploy: + return + eq_s, eq_b = self.get_equivalent_scale_bias() + self.norm_reparam = build_norm_layer(self.norm_cfg, self.dim)[1] + self.norm_reparam.weight.data = eq_s + self.norm_reparam.bias.data = eq_b + self.__delattr__('norm1') + if hasattr(self, 'token_mixer'): + self.__delattr__('token_mixer') + self.deploy = True + + +def basic_blocks(dim, + index, + layers, + mlp_ratio=4., + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + drop_rate=.0, + drop_path_rate=0., + layer_scale_init_value=1e-5, + deploy=False): + """generate RIFormer blocks for a stage.""" + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / ( + sum(layers) - 1) + blocks.append( + RIFormerBlock( + dim, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop=drop_rate, + drop_path=block_dpr, + layer_scale_init_value=layer_scale_init_value, + deploy=deploy, + )) + blocks = nn.Sequential(*blocks) + + return blocks + + +@MODELS.register_module() +class RIFormer(BaseBackbone): + """RIFormer. + + A PyTorch implementation of RIFormer introduced by: + `RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer `_ + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``RIFormer.arch_settings``. 
And if dict, it + should include the following two keys: + + - layers (list[int]): Number of blocks at each stage. + - embed_dims (list[int]): The number of channels at each stage. + - mlp_ratios (list[int]): Expansion ratio of MLPs. + - layer_scale_init_value (float): Init value for Layer Scale. + + Defaults to 'S12'. + + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='LN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + in_patch_size (int): The patch size of/? input image patch embedding. + Defaults to 7. + in_stride (int): The stride of input image patch embedding. + Defaults to 4. + in_pad (int): The padding of input image patch embedding. + Defaults to 2. + down_patch_size (int): The patch size of downsampling patch embedding. + Defaults to 3. + down_stride (int): The stride of downsampling patch embedding. + Defaults to 2. + down_pad (int): The padding of downsampling patch embedding. + Defaults to 1. + drop_rate (float): Dropout rate. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + out_indices (Sequence | int): Output from which network position. + Index 0-6 respectively corresponds to + [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4] + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to -1, which means not freezing any parameters. + deploy (bool): Whether to switch the model structure to + deployment mode. Default: False. + init_cfg (dict, optional): Initialization config dict + """ # noqa: E501 + + # --layers: [x,x,x,x], numbers of layers for the four stages + # --embed_dims, --mlp_ratios: + # embedding dims and mlp ratios for the four stages + # --downsamples: flags to apply downsampling or not in four blocks + arch_settings = { + 's12': { + 'layers': [2, 2, 6, 2], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's24': { + 'layers': [4, 4, 12, 4], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-5, + }, + 's36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [64, 128, 320, 512], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm36': { + 'layers': [6, 6, 18, 6], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + 'm48': { + 'layers': [8, 8, 24, 8], + 'embed_dims': [96, 192, 384, 768], + 'mlp_ratios': [4, 4, 4, 4], + 'layer_scale_init_value': 1e-6, + }, + } + + def __init__(self, + arch='s12', + in_channels=3, + norm_cfg=dict(type='GN', num_groups=1), + act_cfg=dict(type='GELU'), + in_patch_size=7, + in_stride=4, + in_pad=2, + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_rate=0., + drop_path_rate=0., + out_indices=-1, + frozen_stages=-1, + init_cfg=None, + deploy=False): + + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'layers' in arch and 'embed_dims' in arch, \ + f'The arch dict must have "layers" and "embed_dims", ' \ + f'but got {list(arch.keys())}.' 
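+        # Resolve the arch settings; ``mlp_ratios`` and
+        # ``layer_scale_init_value`` fall back to the defaults below when a
+        # custom arch dict does not provide them.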
+ + layers = arch['layers'] + embed_dims = arch['embed_dims'] + mlp_ratios = arch['mlp_ratios'] \ + if 'mlp_ratios' in arch else [4, 4, 4, 4] + layer_scale_init_value = arch['layer_scale_init_value'] \ + if 'layer_scale_init_value' in arch else 1e-5 + + self.patch_embed = PatchEmbed( + patch_size=in_patch_size, + stride=in_stride, + padding=in_pad, + in_chans=in_channels, + embed_dim=embed_dims[0]) + + # set the main block in network + network = [] + for i in range(len(layers)): + stage = basic_blocks( + embed_dims[i], + i, + layers, + mlp_ratio=mlp_ratios[i], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + deploy=deploy) + network.append(stage) + if i >= len(layers) - 1: + break + if embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + padding=down_pad, + in_chans=embed_dims[i], + embed_dim=embed_dims[i + 1])) + + self.network = nn.ModuleList(network) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 7 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + if self.out_indices: + for i_layer in self.out_indices: + layer = build_norm_layer(norm_cfg, + embed_dims[(i_layer + 1) // 2])[1] + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + self._freeze_stages() + self.deploy = deploy + + def forward_embeddings(self, x): + x = self.patch_embed(x) + return x + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(x) + outs.append(x_out) + return tuple(outs) + + def forward(self, x): + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + return x + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + # Include both block and downsample layer. + module = self.network[i] + module.eval() + for param in module.parameters(): + param.requires_grad = False + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(RIFormer, self).train(mode) + self._freeze_stages() + return self + + def switch_to_deploy(self): + for m in self.modules(): + if isinstance(m, RIFormerBlock): + m.switch_to_deploy() + self.deploy = True diff --git a/mmpretrain/models/backbones/seresnet.py b/mmpretrain/models/backbones/seresnet.py new file mode 100644 index 0000000..4437c17 --- /dev/null +++ b/mmpretrain/models/backbones/seresnet.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.utils.checkpoint as cp + +from mmpretrain.registry import MODELS +from ..utils.se_layer import SELayer +from .resnet import Bottleneck, ResLayer, ResNet + + +class SEBottleneck(Bottleneck): + """SEBottleneck block for SEResNet. + + Args: + in_channels (int): The input channels of the SEBottleneck block. 
+ out_channels (int): The output channel of the SEBottleneck block. + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + """ + + def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs): + super(SEBottleneck, self).__init__(in_channels, out_channels, **kwargs) + self.se_layer = SELayer(out_channels, ratio=se_ratio) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + out = self.se_layer(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class SEResNet(ResNet): + """SEResNet backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + + Example: + >>> from mmpretrain.models import SEResNet + >>> import torch + >>> self = SEResNet(depth=50) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 224, 224) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... 
print(tuple(level_out.shape)) + (1, 64, 56, 56) + (1, 128, 28, 28) + (1, 256, 14, 14) + (1, 512, 7, 7) + """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, se_ratio=16, **kwargs): + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for SEResNet') + self.se_ratio = se_ratio + super(SEResNet, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer(se_ratio=self.se_ratio, **kwargs) diff --git a/mmpretrain/models/backbones/seresnext.py b/mmpretrain/models/backbones/seresnext.py new file mode 100644 index 0000000..6a28380 --- /dev/null +++ b/mmpretrain/models/backbones/seresnext.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmpretrain.registry import MODELS +from .resnet import ResLayer +from .seresnet import SEBottleneck as _SEBottleneck +from .seresnet import SEResNet + + +class SEBottleneck(_SEBottleneck): + """SEBottleneck block for SEResNeXt. + + Args: + in_channels (int): Input channels of this block. + out_channels (int): Output channels of this block. + base_channels (int): Middle channels of the first stage. Default: 64. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + stride (int): stride of the block. Default: 1 + dilation (int): dilation of convolution. Default: 1 + downsample (nn.Module, optional): downsample operation on identity + branch. Default: None + se_ratio (int): Squeeze ratio in SELayer. Default: 16 + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + conv_cfg (dict, optional): dictionary to construct and config conv + layer. Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + def __init__(self, + in_channels, + out_channels, + base_channels=64, + groups=32, + width_per_group=4, + se_ratio=16, + **kwargs): + super(SEBottleneck, self).__init__(in_channels, out_channels, se_ratio, + **kwargs) + self.groups = groups + self.width_per_group = width_per_group + + # We follow the same rational of ResNext to compute mid_channels. + # For SEResNet bottleneck, middle channels are determined by expansion + # and out_channels, but for SEResNeXt bottleneck, it is determined by + # groups and width_per_group and the stage it is located in. 
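+        # For example, with the default 32x4d setting, stage 1 has
+        # out_channels = 256, so the inherited mid_channels of 64 becomes
+        # 32 * 4 * 64 // 64 = 128, i.e. 4 channels per group.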
+ if groups != 1: + assert self.mid_channels % base_channels == 0 + self.mid_channels = ( + groups * width_per_group * self.mid_channels // base_channels) + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, self.mid_channels, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.out_channels, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.in_channels, + self.mid_channels, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.mid_channels, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + self.mid_channels, + self.out_channels, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class SEResNeXt(SEResNet): + """SEResNeXt backbone. + + Please refer to the `paper `__ for + details. + + Args: + depth (int): Network depth, from {50, 101, 152}. + groups (int): Groups of conv2 in Bottleneck. Default: 32. + width_per_group (int): Width per group of conv2 in Bottleneck. + Default: 4. + se_ratio (int): Squeeze ratio in SELayer. Default: 16. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Output channels of the stem layer. Default: 64. + num_stages (int): Stages of the network. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Default: ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. If only one + stage is specified, a single tensor (feature map) is returned, + otherwise multiple stages are specified, a tuple of tensors will + be returned. Default: ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): The config dict for conv layers. Default: None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. 
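+
+    Example:
+        An illustrative forward pass; the printed shape assumes the
+        default ``out_indices=(3, )`` and a 224x224 input.
+
+        >>> from mmpretrain.models import SEResNeXt
+        >>> import torch
+        >>> self = SEResNeXt(depth=50, groups=32, width_per_group=4)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> level_outputs = self.forward(inputs)
+        >>> print(tuple(level_outputs[-1].shape))
+        (1, 2048, 7, 7)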
+ """ + + arch_settings = { + 50: (SEBottleneck, (3, 4, 6, 3)), + 101: (SEBottleneck, (3, 4, 23, 3)), + 152: (SEBottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, groups=32, width_per_group=4, **kwargs): + self.groups = groups + self.width_per_group = width_per_group + super(SEResNeXt, self).__init__(depth, **kwargs) + + def make_res_layer(self, **kwargs): + return ResLayer( + groups=self.groups, + width_per_group=self.width_per_group, + base_channels=self.base_channels, + **kwargs) diff --git a/mmpretrain/models/backbones/shufflenet_v1.py b/mmpretrain/models/backbones/shufflenet_v1.py new file mode 100644 index 0000000..2cc3617 --- /dev/null +++ b/mmpretrain/models/backbones/shufflenet_v1.py @@ -0,0 +1,321 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_activation_layer +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import channel_shuffle, make_divisible +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class ShuffleUnit(BaseModule): + """ShuffleUnit block. + + ShuffleNet unit with pointwise group convolution (GConv) and channel + shuffle. + + Args: + in_channels (int): The input channels of the ShuffleUnit. + out_channels (int): The output channels of the ShuffleUnit. + groups (int): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3 + first_block (bool): Whether it is the first ShuffleUnit of a + sequential ShuffleUnits. Default: True, which means not using the + grouped 1x1 convolution. + combine (str): The ways to combine the input and output + branches. Default: 'add'. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + groups=3, + first_block=True, + combine='add', + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super(ShuffleUnit, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.first_block = first_block + self.combine = combine + self.groups = groups + self.bottleneck_channels = self.out_channels // 4 + self.with_cp = with_cp + + if self.combine == 'add': + self.depthwise_stride = 1 + self._combine_func = self._add + assert in_channels == out_channels, ( + 'in_channels must be equal to out_channels when combine ' + 'is add') + elif self.combine == 'concat': + self.depthwise_stride = 2 + self._combine_func = self._concat + self.out_channels -= self.in_channels + self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + raise ValueError(f'Cannot combine tensors with {self.combine}. 
' + 'Only "add" and "concat" are supported') + + self.first_1x1_groups = 1 if first_block else self.groups + self.g_conv_1x1_compress = ConvModule( + in_channels=self.in_channels, + out_channels=self.bottleneck_channels, + kernel_size=1, + groups=self.first_1x1_groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.depthwise_conv3x3_bn = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.bottleneck_channels, + kernel_size=3, + stride=self.depthwise_stride, + padding=1, + groups=self.bottleneck_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.g_conv_1x1_expand = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.out_channels, + kernel_size=1, + groups=self.groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + self.act = build_activation_layer(act_cfg) + + @staticmethod + def _add(x, out): + # residual connection + return x + out + + @staticmethod + def _concat(x, out): + # concatenate along channel axis + return torch.cat((x, out), 1) + + def forward(self, x): + + def _inner_forward(x): + residual = x + + out = self.g_conv_1x1_compress(x) + out = self.depthwise_conv3x3_bn(out) + + if self.groups > 1: + out = channel_shuffle(out, self.groups) + + out = self.g_conv_1x1_expand(out) + + if self.combine == 'concat': + residual = self.avgpool(residual) + out = self.act(out) + out = self._combine_func(residual, out) + else: + out = self._combine_func(residual, out) + out = self.act(out) + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class ShuffleNetV1(BaseBackbone): + """ShuffleNetV1 backbone. + + Args: + groups (int): The number of groups to be used in grouped 1x1 + convolutions in each ShuffleUnit. Default: 3. + widen_factor (float): Width multiplier - adjusts the number + of channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (2, ) + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + groups=3, + widen_factor=1.0, + out_indices=(2, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=None): + super(ShuffleNetV1, self).__init__(init_cfg) + self.init_cfg = init_cfg + self.stage_blocks = [4, 8, 4] + self.groups = groups + + for index in out_indices: + if index not in range(0, 3): + raise ValueError('the item in out_indices must in ' + f'range(0, 3). But received {index}') + + if frozen_stages not in range(-1, 3): + raise ValueError('frozen_stages must be in range(-1, 3). 
' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if groups == 1: + channels = (144, 288, 576) + elif groups == 2: + channels = (200, 400, 800) + elif groups == 3: + channels = (240, 480, 960) + elif groups == 4: + channels = (272, 544, 1088) + elif groups == 8: + channels = (384, 768, 1536) + else: + raise ValueError(f'{groups} groups is not supported for 1x1 ' + 'Grouped Convolutions') + + channels = [make_divisible(ch * widen_factor, 8) for ch in channels] + + self.in_channels = int(24 * widen_factor) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + first_block = True if i == 0 else False + layer = self.make_layer(channels[i], num_blocks, first_block) + self.layers.append(layer) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(self.frozen_stages): + layer = self.layers[i] + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ShuffleNetV1, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'conv1' in name: + normal_init(m, mean=0, std=0.01) + else: + normal_init(m, mean=0, std=1.0 / m.weight.shape[1]) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, val=1, bias=0.0001) + if isinstance(m, _BatchNorm): + if m.running_mean is not None: + nn.init.constant_(m.running_mean, 0) + + def make_layer(self, out_channels, num_blocks, first_block=False): + """Stack ShuffleUnit blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): Number of blocks. + first_block (bool): Whether is the first ShuffleUnit of a + sequential ShuffleUnits. Default: False, which means using + the grouped 1x1 convolution. + """ + layers = [] + for i in range(num_blocks): + first_block = first_block if i == 0 else False + combine_mode = 'concat' if i == 0 else 'add' + layers.append( + ShuffleUnit( + self.in_channels, + out_channels, + groups=self.groups, + first_block=first_block, + combine=combine_mode, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def train(self, mode=True): + super(ShuffleNetV1, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/shufflenet_v2.py b/mmpretrain/models/backbones/shufflenet_v2.py new file mode 100644 index 0000000..02f9c74 --- /dev/null +++ b/mmpretrain/models/backbones/shufflenet_v2.py @@ -0,0 +1,305 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init, normal_init +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.utils import channel_shuffle +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class InvertedResidual(BaseModule): + """InvertedResidual block for ShuffleNetV2 backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + stride (int): Stride of the 3x3 convolution layer. Default: 1 + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.stride = stride + self.with_cp = with_cp + + branch_features = out_channels // 2 + if self.stride == 1: + assert in_channels == branch_features * 2, ( + f'in_channels ({in_channels}) should equal to ' + f'branch_features * 2 ({branch_features * 2}) ' + 'when stride is 1') + + if in_channels != branch_features * 2: + assert self.stride != 1, ( + f'stride ({self.stride}) should not equal 1 when ' + f'in_channels != branch_features * 2') + + if self.stride > 1: + self.branch1 = nn.Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=self.stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + + self.branch2 = nn.Sequential( + ConvModule( + in_channels if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + groups=branch_features, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + + def _inner_forward(x): + if self.stride > 1: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + else: + # Channel Split operation. using these lines of code to replace + # ``chunk(x, 2, dim=1)`` can make it easier to deploy a + # shufflenetv2 model by using mmdeploy. + channels = x.shape[1] + c = channels // 2 + channels % 2 + x1 = x[:, :c, :, :] + x2 = x[:, c:, :, :] + + out = torch.cat((x1, self.branch2(x2)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@MODELS.register_module() +class ShuffleNetV2(BaseBackbone): + """ShuffleNetV2 backbone. 
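A side note on the channel-split trick in InvertedResidual above: the explicit slicing with c = channels // 2 + channels % 2 reproduces torch.chunk(x, 2, dim=1) while staying export-friendly, and channel_shuffle then interleaves the two halves so information can flow between branches in the next block. The snippet below is a self-contained sketch; channel_shuffle_ref is a local re-implementation written for illustration, not the mmpretrain.models.utils.channel_shuffle import used above.

    import torch

    def channel_shuffle_ref(x, groups):
        # reshape to (N, groups, C // groups, H, W), swap the two channel axes, flatten back
        n, c, h, w = x.shape
        return x.view(n, groups, c // groups, h, w).transpose(1, 2).reshape(n, c, h, w)

    x = torch.arange(8.).view(1, 8, 1, 1)
    c = x.shape[1] // 2 + x.shape[1] % 2
    x1, x2 = x[:, :c], x[:, c:]                  # the deploy-friendly split
    a, b = torch.chunk(x, 2, dim=1)              # the usual split
    assert torch.equal(x1, a) and torch.equal(x2, b)

    shuffled = channel_shuffle_ref(torch.cat((x1, x2), dim=1), groups=2)
    print(shuffled.flatten().tolist())           # the two halves interleaved: 0, 4, 1, 5, ...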
+ + Args: + widen_factor (float): Width multiplier - adjusts the number of + channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict, optional): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + widen_factor=1.0, + out_indices=(3, ), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=None): + super(ShuffleNetV2, self).__init__(init_cfg) + self.stage_blocks = [4, 8, 4] + for index in out_indices: + if index not in range(0, 4): + raise ValueError('the item in out_indices must in ' + f'range(0, 4). But received {index}') + + if frozen_stages not in range(-1, 4): + raise ValueError('frozen_stages must be in range(-1, 4). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + if widen_factor == 0.5: + channels = [48, 96, 192, 1024] + elif widen_factor == 1.0: + channels = [116, 232, 464, 1024] + elif widen_factor == 1.5: + channels = [176, 352, 704, 1024] + elif widen_factor == 2.0: + channels = [244, 488, 976, 2048] + else: + raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. ' + f'But received {widen_factor}') + + self.in_channels = 24 + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layers = nn.ModuleList() + for i, num_blocks in enumerate(self.stage_blocks): + layer = self._make_layer(channels[i], num_blocks) + self.layers.append(layer) + + output_channels = channels[-1] + self.layers.append( + ConvModule( + in_channels=self.in_channels, + out_channels=output_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def _make_layer(self, out_channels, num_blocks): + """Stack blocks to make a layer. + + Args: + out_channels (int): out_channels of the block. + num_blocks (int): number of blocks. 
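A minimal usage sketch for the backbone configured above, assuming it is importable as mmpretrain.models.backbones.ShuffleNetV2. With widen_factor=1.0 the stages use [116, 232, 464] channels and the trailing 1x1 ConvModule projects to 1024, so the default out_indices=(3,) should yield a single stride-32 feature map; the printed shape below is the expected value, not one verified by this patch.

    import torch
    from mmpretrain.models.backbones import ShuffleNetV2  # assumed import path

    model = ShuffleNetV2(widen_factor=1.0)
    model.eval()
    with torch.no_grad():
        feat, = model(torch.randn(1, 3, 224, 224))
    print(tuple(feat.shape))  # expected: (1, 1024, 7, 7)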
+ """ + layers = [] + for i in range(num_blocks): + stride = 2 if i == 0 else 1 + layers.append( + InvertedResidual( + in_channels=self.in_channels, + out_channels=out_channels, + stride=stride, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + + for i in range(self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + super(ShuffleNetV2, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'conv1' in name: + normal_init(m, mean=0, std=0.01) + else: + normal_init(m, mean=0, std=1.0 / m.weight.shape[1]) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m.weight, val=1, bias=0.0001) + if isinstance(m, _BatchNorm): + if m.running_mean is not None: + nn.init.constant_(m.running_mean, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def train(self, mode=True): + super(ShuffleNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() diff --git a/mmpretrain/models/backbones/sparse_convnext.py b/mmpretrain/models/backbones/sparse_convnext.py new file mode 100644 index 0000000..8f36136 --- /dev/null +++ b/mmpretrain/models/backbones/sparse_convnext.py @@ -0,0 +1,298 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmengine.model import ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import (SparseAvgPooling, SparseConv2d, SparseHelper, + SparseMaxPooling, build_norm_layer) +from .convnext import ConvNeXt, ConvNeXtBlock + + +class SparseConvNeXtBlock(ConvNeXtBlock): + """Sparse ConvNeXt Block. + + Note: + There are two equivalent implementations: + 1. DwConv -> SparseLayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv; + all outputs are in (N, C, H, W). + 2. DwConv -> SparseLayerNorm -> Permute to (N, H, W, C) -> Linear -> + GELU -> Linear; Permute back + As default, we use the second to align with the official repository. + And it may be slightly faster. 
+ """ + + def forward(self, x): + + def _inner_forward(x): + shortcut = x + x = self.depthwise_conv(x) + + if self.linear_pw_conv: + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x, data_format='channel_last') + x = self.pointwise_conv1(x) + x = self.act(x) + if self.grn is not None: + x = self.grn(x, data_format='channel_last') + x = self.pointwise_conv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + else: + x = self.norm(x, data_format='channel_first') + x = self.pointwise_conv1(x) + x = self.act(x) + + if self.grn is not None: + x = self.grn(x, data_format='channel_first') + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = x.mul(self.gamma.view(1, -1, 1, 1)) + + x *= SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=True) + + x = shortcut + self.drop_path(x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class SparseConvNeXt(ConvNeXt): + """ConvNeXt with sparse module conversion function. + + Modified from + https://github.com/keyu-tian/SparK/blob/main/models/convnext.py + and + https://github.com/keyu-tian/SparK/blob/main/encoder.py + To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvNeXt.arch_settings``. And if dict, it + should include the following two keys: + - depths (list[int]): Number of blocks at each stage. + - channels (list[int]): The number of channels at each stage. + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + stem_patch_size (int): The size of one patch in the stem layer. + Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='SparseLN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to True. + use_grn (bool): Whether to add Global Response Normalization in the + blocks. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_output (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): Initialization config dict. 
+ """ # noqa: E501 + + def __init__(self, + arch: str = 'small', + in_channels: int = 3, + stem_patch_size: int = 4, + norm_cfg: dict = dict(type='SparseLN2d', eps=1e-6), + act_cfg: dict = dict(type='GELU'), + linear_pw_conv: bool = True, + use_grn: bool = False, + drop_path_rate: float = 0, + layer_scale_init_value: float = 1e-6, + out_indices: int = -1, + frozen_stages: int = 0, + gap_before_output: bool = True, + with_cp: bool = False, + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict( + type='TruncNormal', + layer=['Conv2d', 'Linear'], + std=.02, + bias=0.), + dict( + type='Constant', layer=['LayerNorm'], val=1., + bias=0.), + ]): + super(ConvNeXt, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'depths' in arch and 'channels' in arch, \ + f'The arch dict must have "depths" and "channels", ' \ + f'but got {list(arch.keys())}.' + + self.depths = arch['depths'] + self.channels = arch['channels'] + assert (isinstance(self.depths, Sequence) + and isinstance(self.channels, Sequence) + and len(self.depths) == len(self.channels)), \ + f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \ + 'should be both sequence with the same length.' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_output = gap_before_output + + # 4 downsample layers between stages, including the stem layer. 
+ self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.channels[0], + kernel_size=stem_patch_size, + stride=stem_patch_size), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + block_idx = 0 + + # 4 feature resolution stages, each consisting of multiple residual + # blocks + self.stages = nn.ModuleList() + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2), + ) + self.downsample_layers.append(downsample_layer) + + stage = Sequential(*[ + SparseConvNeXtBlock( + in_channels=channels, + drop_path_rate=dpr[block_idx + j], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + layer_scale_init_value=layer_scale_init_value, + use_grn=use_grn, + with_cp=with_cp) for j in range(depth) + ]) + block_idx += depth + + self.stages.append(stage) + + self.dense_model_to_sparse(m=self) + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if i in self.out_indices: + if self.gap_before_output: + gap = x.mean([-2, -1], keepdim=True) + outs.append(gap.flatten(1)) + else: + outs.append(x) + + return tuple(outs) + + def dense_model_to_sparse(self, m: nn.Module) -> nn.Module: + """Convert regular dense modules to sparse modules.""" + output = m + if isinstance(m, nn.Conv2d): + m: nn.Conv2d + bias = m.bias is not None + output = SparseConv2d( + m.in_channels, + m.out_channels, + kernel_size=m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=bias, + padding_mode=m.padding_mode, + ) + output.weight.data.copy_(m.weight.data) + if bias: + output.bias.data.copy_(m.bias.data) + + elif isinstance(m, nn.MaxPool2d): + m: nn.MaxPool2d + output = SparseMaxPooling( + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + return_indices=m.return_indices, + ceil_mode=m.ceil_mode) + + elif isinstance(m, nn.AvgPool2d): + m: nn.AvgPool2d + output = SparseAvgPooling( + m.kernel_size, + m.stride, + m.padding, + ceil_mode=m.ceil_mode, + count_include_pad=m.count_include_pad, + divisor_override=m.divisor_override) + + # elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + # m: nn.BatchNorm2d + # output = (SparseSyncBatchNorm2d + # if enable_sync_bn else SparseBatchNorm2d)( + # m.weight.shape[0], + # eps=m.eps, + # momentum=m.momentum, + # affine=m.affine, + # track_running_stats=m.track_running_stats) + # output.weight.data.copy_(m.weight.data) + # output.bias.data.copy_(m.bias.data) + # output.running_mean.data.copy_(m.running_mean.data) + # output.running_var.data.copy_(m.running_var.data) + # output.num_batches_tracked.data.copy_(m.num_batches_tracked.data) + + for name, child in m.named_children(): + output.add_module(name, self.dense_model_to_sparse(child)) + del m + return output diff --git a/mmpretrain/models/backbones/sparse_resnet.py b/mmpretrain/models/backbones/sparse_resnet.py new file mode 100644 index 0000000..67597f1 --- /dev/null +++ b/mmpretrain/models/backbones/sparse_resnet.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
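One detail worth spelling out from the stage construction above: the per-block drop-path probabilities follow the stochastic depth decay rule, a single torch.linspace from 0 to drop_path_rate over all blocks of all stages, with each stage consuming its contiguous slice via block_idx. A small numeric sketch with hypothetical depths:

    import torch

    depths, drop_path_rate = [3, 3, 9, 3], 0.1   # hypothetical ConvNeXt-style depths
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

    block_idx, per_stage = 0, []
    for depth in depths:
        per_stage.append(dpr[block_idx:block_idx + depth])
        block_idx += depth
    # rates grow monotonically; the deepest block gets (approximately) the full rate
    assert abs(per_stage[-1][-1] - drop_path_rate) < 1e-6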
+import re +from typing import Optional, Tuple + +import torch.nn as nn + +from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling, + SparseBatchNorm2d, + SparseConv2d, + SparseMaxPooling, + SparseSyncBatchNorm2d) +from mmpretrain.registry import MODELS +from .resnet import ResNet + + +@MODELS.register_module() +class SparseResNet(ResNet): + """ResNet with sparse module conversion function. + + Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Defaults to 3. + stem_channels (int): Output channels of the stem layer. Defaults to 64. + base_channels (int): Middle channels of the first stage. + Defaults to 64. + num_stages (int): Stages of the network. Defaults to 4. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + conv_cfg (dict | None): The config dict for conv layers. + Defaults to None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to True. + drop_path_rate (float): stochastic depth rate. Defaults to 0. 
+ """ + + def __init__(self, + depth: int, + in_channels: int = 3, + stem_channels: int = 64, + base_channels: int = 64, + expansion: Optional[int] = None, + num_stages: int = 4, + strides: Tuple[int] = (1, 2, 2, 2), + dilations: Tuple[int] = (1, 1, 1, 1), + out_indices: Tuple[int] = (3, ), + style: str = 'pytorch', + deep_stem: bool = False, + avg_down: bool = False, + frozen_stages: int = -1, + conv_cfg: Optional[dict] = None, + norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'), + norm_eval: bool = False, + with_cp: bool = False, + zero_init_residual: bool = False, + init_cfg: Optional[dict] = [ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ], + drop_path_rate: float = 0, + **kwargs): + super().__init__( + depth=depth, + in_channels=in_channels, + stem_channels=stem_channels, + base_channels=base_channels, + expansion=expansion, + num_stages=num_stages, + strides=strides, + dilations=dilations, + out_indices=out_indices, + style=style, + deep_stem=deep_stem, + avg_down=avg_down, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + norm_eval=norm_eval, + with_cp=with_cp, + zero_init_residual=zero_init_residual, + init_cfg=init_cfg, + drop_path_rate=drop_path_rate, + **kwargs) + norm_type = norm_cfg['type'] + enable_sync_bn = False + if re.search('Sync', norm_type) is not None: + enable_sync_bn = True + self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn) + + def dense_model_to_sparse(self, m: nn.Module, + enable_sync_bn: bool) -> nn.Module: + """Convert regular dense modules to sparse modules.""" + output = m + if isinstance(m, nn.Conv2d): + m: nn.Conv2d + bias = m.bias is not None + output = SparseConv2d( + m.in_channels, + m.out_channels, + kernel_size=m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=bias, + padding_mode=m.padding_mode, + ) + output.weight.data.copy_(m.weight.data) + if bias: + output.bias.data.copy_(m.bias.data) + + elif isinstance(m, nn.MaxPool2d): + m: nn.MaxPool2d + output = SparseMaxPooling( + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + return_indices=m.return_indices, + ceil_mode=m.ceil_mode) + + elif isinstance(m, nn.AvgPool2d): + m: nn.AvgPool2d + output = SparseAvgPooling( + m.kernel_size, + m.stride, + m.padding, + ceil_mode=m.ceil_mode, + count_include_pad=m.count_include_pad, + divisor_override=m.divisor_override) + + elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + m: nn.BatchNorm2d + output = (SparseSyncBatchNorm2d + if enable_sync_bn else SparseBatchNorm2d)( + m.weight.shape[0], + eps=m.eps, + momentum=m.momentum, + affine=m.affine, + track_running_stats=m.track_running_stats) + output.weight.data.copy_(m.weight.data) + output.bias.data.copy_(m.bias.data) + output.running_mean.data.copy_(m.running_mean.data) + output.running_var.data.copy_(m.running_var.data) + output.num_batches_tracked.data.copy_(m.num_batches_tracked.data) + + elif isinstance(m, (nn.Conv1d, )): + raise NotImplementedError + + for name, child in m.named_children(): + output.add_module( + name, + self.dense_model_to_sparse( + child, enable_sync_bn=enable_sync_bn)) + del m + return output diff --git a/mmpretrain/models/backbones/swin_transformer.py b/mmpretrain/models/backbones/swin_transformer.py new file mode 100644 index 0000000..559fd5e --- /dev/null +++ b/mmpretrain/models/backbones/swin_transformer.py @@ -0,0 +1,585 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import (ShiftWindowMSA, resize_pos_embed, + resize_relative_position_bias_table, to_2tuple) +from .base_backbone import BaseBackbone + + +class SwinBlock(BaseModule): + """Swin Transformer block. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + shift (bool): Shift the attention window or not. Defaults to False. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=7, + shift=False, + ffn_ratio=4., + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(SwinBlock, self).__init__(init_cfg) + self.with_cp = with_cp + + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': window_size // 2 if shift else 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'pad_small_map': pad_small_map, + **attn_cfgs + } + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA(**_attn_cfgs) + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + **ffn_cfgs + } + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN(**_ffn_cfgs) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockSequence(BaseModule): + """Module with successive Swin Transformer blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive swin transformer blocks. + num_heads (int): Number of attention heads. 
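The with_cp branch in SwinBlock.forward above wraps the whole pre-norm block (attention branch, then FFN branch) in torch.utils.checkpoint, which drops intermediate activations and recomputes them during backward, trading compute for memory. A generic, minimal sketch of that pattern; the inner function here is just a stand-in for the norm/attention/FFN sequence.

    import torch
    import torch.utils.checkpoint as cp

    def _inner_forward(x):
        # stand-in for: x = x + attn(norm1(x)); x = ffn(norm2(x), identity=x)
        return x + torch.tanh(x)

    x = torch.randn(2, 49, 96, requires_grad=True)
    out = cp.checkpoint(_inner_forward, x)   # activations recomputed in the backward pass
    out.sum().backward()
    print(x.grad.shape)                      # torch.Size([2, 49, 96])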
+ window_size (int): The height and width of the window. Defaults to 7. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=7, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + self.embed_dims = embed_dims + self.blocks = ModuleList() + for i in range(depth): + _block_cfg = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'window_size': window_size, + 'shift': False if i % 2 == 0 else True, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **block_cfgs[i] + } + block = SwinBlock(**_block_cfg) + self.blocks.append(block) + + if downsample: + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': 2 * embed_dims, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = PatchMerging(**_downsample_cfg) + else: + self.downsample = None + + def forward(self, x, in_shape, do_downsample=True): + for block in self.blocks: + x = block(x, in_shape) + + if self.downsample is not None and do_downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + return x, out_shape + + @property + def out_channels(self): + if self.downsample: + return self.downsample.out_channels + else: + return self.embed_dims + + +@MODELS.register_module() +class SwinTransformer(BaseBackbone): + """Swin Transformer. + + A PyTorch implement of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows + `_ + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + + Defaults to 'tiny'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int): The height and width of the window. Defaults to 7. + drop_rate (float): Dropout rate after embedding. 
Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_after_downsample (bool): Whether to output the feature map of a + stage after the following downsample layer. Defaults to False. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embeding vector resize. Defaults to "bicubic". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import SwinTransformer + >>> import torch + >>> extra_config = dict( + >>> arch='tiny', + >>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3, + >>> 'expansion_ratio': 3})) + >>> self = SwinTransformer(**extra_config) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> output = self.forward(inputs) + >>> print(output.shape) + (1, 2592, 4) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 96, + 'depths': [2, 2, 6, 2], + 'num_heads': [3, 6, 12, 24]}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 96, + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48]}), + } # yapf: disable + + _version = 3 + num_extra_tokens = 0 + + def __init__(self, + arch='tiny', + img_size=224, + patch_size=4, + in_channels=3, + window_size=7, + drop_rate=0., + drop_path_rate=0.1, + out_indices=(3, ), + out_after_downsample=False, + use_abs_pos_embed=False, + interpolate_mode='bicubic', + with_cp=False, + frozen_stages=-1, + norm_eval=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + patch_cfg=dict(), + init_cfg=None): + super(SwinTransformer, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'num_heads'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.num_layers = 
len(self.depths) + self.out_indices = out_indices + self.out_after_downsample = out_after_downsample + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + self.frozen_stages = frozen_stages + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook( + self._prepare_abs_pos_embed) + + self._register_load_state_dict_pre_hook( + self._prepare_relative_position_bias_table) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.norm_eval = norm_eval + + # stochastic depth + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i < self.num_layers - 1 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': window_size, + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + **stage_cfg + } + + stage = SwinBlockSequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + if self.out_after_downsample: + self.num_features = embed_dims[1:] + else: + self.num_features = embed_dims[:-1] + + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, + self.num_features[i])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self): + super(SwinTransformer, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage( + x, hw_shape, do_downsample=self.out_after_downsample) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + if stage.downsample is not None and not self.out_after_downsample: + x, hw_shape = stage.downsample(x, hw_shape) + + return tuple(outs) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args, + **kwargs): + """load checkpoints.""" + # Names of some parameters in has been changed. 
+ version = local_metadata.get('version', None) + if (version is None + or version < 2) and self.__class__ is SwinTransformer: + final_stage_num = len(self.stages) - 1 + state_dict_keys = list(state_dict.keys()) + for k in state_dict_keys: + if k.startswith('norm.') or k.startswith('backbone.norm.'): + convert_key = k.replace('norm.', f'norm{final_stage_num}.') + state_dict[convert_key] = state_dict[k] + del state_dict[k] + if (version is None + or version < 3) and self.__class__ is SwinTransformer: + state_dict_keys = list(state_dict.keys()) + for k in state_dict_keys: + if 'attn_mask' in k: + del state_dict[k] + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + *args, **kwargs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(SwinTransformer, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'absolute_pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.absolute_pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Resize the absolute_pos_embed shape from ' + f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def _prepare_relative_position_bias_table(self, state_dict, prefix, *args, + **kwargs): + state_dict_model = self.state_dict() + all_keys = list(state_dict_model.keys()) + for key in all_keys: + if 'relative_position_bias_table' in key: + ckpt_key = prefix + key + if ckpt_key not in state_dict: + continue + relative_position_bias_table_pretrained = state_dict[ckpt_key] + relative_position_bias_table_current = state_dict_model[key] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + if L1 != L2: + src_size = int(L1**0.5) + dst_size = int(L2**0.5) + new_rel_pos_bias = resize_relative_position_bias_table( + src_size, dst_size, + relative_position_bias_table_pretrained, nH1) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info('Resize the relative_position_bias_table from ' + f'{state_dict[ckpt_key].shape} to ' + f'{new_rel_pos_bias.shape}') + state_dict[ckpt_key] = new_rel_pos_bias + + # The index buffer need to be re-generated. + index_buffer = ckpt_key.replace('bias_table', 'index') + del state_dict[index_buffer] + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. 
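For context on the square roots in _prepare_relative_position_bias_table above: a square window of size W has a relative-position bias table with (2W - 1) ** 2 entries per head, so int(L ** 0.5) recovers the 2W - 1 side length that resize_relative_position_bias_table interpolates between. A quick check of that relationship:

    for w in (7, 8, 12, 24):
        table_len = (2 * w - 1) ** 2          # rows of the bias table for a W x W window
        assert int(table_len ** 0.5) == 2 * w - 1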
+ Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = sum(self.depths) + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('stages'): + stage_id = int(param_name.split('.')[1]) + block_id = param_name.split('.')[3] + if block_id in ('reduction', 'norm'): + layer_depth = sum(self.depths[:stage_id + 1]) + else: + layer_depth = sum(self.depths[:stage_id]) + int(block_id) + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/swin_transformer_v2.py b/mmpretrain/models/backbones/swin_transformer_v2.py new file mode 100644 index 0000000..142505a --- /dev/null +++ b/mmpretrain/models/backbones/swin_transformer_v2.py @@ -0,0 +1,567 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from ..builder import MODELS +from ..utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2, + resize_pos_embed, to_2tuple) +from .base_backbone import BaseBackbone + + +class SwinBlockV2(BaseModule): + """Swin Transformer V2 block. Use post normalization. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + shift (bool): Shift the attention window or not. Defaults to False. + extra_norm (bool): Whether add extra norm at the end of main branch. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pretrained_window_size (int): Window size in pretrained. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
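The get_layer_depth method above supports layer-wise learning-rate decay: the patch embedding maps to depth 0, a block named stages.S.blocks.B.* maps to the cumulative block count before stage S plus B + 1, and any other parameter (such as a classification head) lands on the final depth. A small sanity check that mirrors that arithmetic for the 'tiny' depths [2, 2, 6, 2]:

    depths = [2, 2, 6, 2]          # Swin-tiny block counts
    num_layers = sum(depths) + 2   # patch embed + all blocks + everything after

    def layer_depth(stage_id, block_id):
        # mirrors the 'stages' branch of get_layer_depth for a 'stages.S.blocks.B.*' name
        return sum(depths[:stage_id]) + block_id + 1

    assert layer_depth(0, 0) == 1             # first block, right after the patch embed
    assert layer_depth(3, 1) == sum(depths)   # deepest block sits just before the tail
    assert num_layers == 14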
+ """ + + def __init__(self, + embed_dims, + num_heads, + window_size=8, + shift=False, + extra_norm=False, + ffn_ratio=4., + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + pretrained_window_size=0, + init_cfg=None): + + super(SwinBlockV2, self).__init__(init_cfg) + self.with_cp = with_cp + self.extra_norm = extra_norm + + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': window_size // 2 if shift else 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'pad_small_map': pad_small_map, + **attn_cfgs + } + # use V2 attention implementation + _attn_cfgs.update( + window_msa=WindowMSAV2, + pretrained_window_size=to_2tuple(pretrained_window_size)) + self.attn = ShiftWindowMSA(**_attn_cfgs) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + 'add_identity': False, + **ffn_cfgs + } + self.ffn = FFN(**_ffn_cfgs) + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + # add extra norm for every n blocks in huge and giant model + if self.extra_norm: + self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, x, hw_shape): + + def _inner_forward(x): + # Use post normalization + identity = x + x = self.attn(x, hw_shape) + x = self.norm1(x) + x = x + identity + + identity = x + x = self.ffn(x) + x = self.norm2(x) + x = x + identity + + if self.extra_norm: + x = self.norm3(x) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockV2Sequence(BaseModule): + """Module with successive Swin Transformer blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive swin transformer blocks. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + extra_norm_every_n_blocks (int): Add extra norm at the end of main + branch every n blocks. Defaults to 0, which means no needs for + extra norm layer. + pretrained_window_size (int): Window size in pretrained. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=8, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + extra_norm_every_n_blocks=0, + pretrained_window_size=0, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + if downsample: + self.out_channels = 2 * embed_dims + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': self.out_channels, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = PatchMerging(**_downsample_cfg) + else: + self.out_channels = embed_dims + self.downsample = None + + self.blocks = ModuleList() + for i in range(depth): + extra_norm = True if extra_norm_every_n_blocks and \ + (i + 1) % extra_norm_every_n_blocks == 0 else False + _block_cfg = { + 'embed_dims': self.out_channels, + 'num_heads': num_heads, + 'window_size': window_size, + 'shift': False if i % 2 == 0 else True, + 'extra_norm': extra_norm, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + 'pretrained_window_size': pretrained_window_size, + **block_cfgs[i] + } + block = SwinBlockV2(**_block_cfg) + self.blocks.append(block) + + def forward(self, x, in_shape): + if self.downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + + for block in self.blocks: + x = block(x, out_shape) + + return x, out_shape + + +@MODELS.register_module() +class SwinTransformerV2(BaseBackbone): + """Swin Transformer V2. + + A PyTorch implement of : `Swin Transformer V2: + Scaling Up Capacity and Resolution + `_ + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + - **extra_norm_every_n_blocks** (int): Add extra norm at the end + of main branch every n blocks. + + Defaults to 'tiny'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int | Sequence): The height and width of the window. + Defaults to 7. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embeding vector resize. Defaults to "bicubic". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). 
Note: Effect on Batch Norm + and its variants only. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + pretrained_window_sizes (tuple(int)): Pretrained window sizes of + each layer. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import SwinTransformerV2 + >>> import torch + >>> extra_config = dict( + >>> arch='tiny', + >>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3, + >>> 'padding': 'same'})) + >>> self = SwinTransformerV2(**extra_config) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> output = self.forward(inputs) + >>> print(output.shape) + (1, 2592, 4) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 96, + 'depths': [2, 2, 6, 2], + 'num_heads': [3, 6, 12, 24], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 96, + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48], + 'extra_norm_every_n_blocks': 0}), + # head count not certain for huge, and is employed for another + # parallel study about self-supervised learning. 
+ **dict.fromkeys(['h', 'huge'], + {'embed_dims': 352, + 'depths': [2, 2, 18, 2], + 'num_heads': [8, 16, 32, 64], + 'extra_norm_every_n_blocks': 6}), + **dict.fromkeys(['g', 'giant'], + {'embed_dims': 512, + 'depths': [2, 2, 42, 4], + 'num_heads': [16, 32, 64, 128], + 'extra_norm_every_n_blocks': 6}), + } # yapf: disable + + _version = 1 + num_extra_tokens = 0 + + def __init__(self, + arch='tiny', + img_size=256, + patch_size=4, + in_channels=3, + window_size=8, + drop_rate=0., + drop_path_rate=0.1, + out_indices=(3, ), + use_abs_pos_embed=False, + interpolate_mode='bicubic', + with_cp=False, + frozen_stages=-1, + norm_eval=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(), + patch_cfg=dict(), + pretrained_window_sizes=[0, 0, 0, 0], + init_cfg=None): + super(SwinTransformerV2, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'depths', 'num_heads', + 'extra_norm_every_n_blocks' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.extra_norm_every_n_blocks = self.arch_settings[ + 'extra_norm_every_n_blocks'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + self.frozen_stages = frozen_stages + + if isinstance(window_size, int): + self.window_sizes = [window_size for _ in range(self.num_layers)] + elif isinstance(window_size, Sequence): + assert len(window_size) == self.num_layers, \ + f'Length of window_sizes {len(window_size)} is not equal to '\ + f'length of stages {self.num_layers}.' 
+ self.window_sizes = window_size + else: + raise TypeError('window_size should be a Sequence or int.') + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook( + self._prepare_abs_pos_embed) + + self._register_load_state_dict_pre_hook(self._delete_reinit_params) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.norm_eval = norm_eval + + # stochastic depth + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, + num_heads) in enumerate(zip(self.depths, self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i > 0 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': self.window_sizes[i], + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + 'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks, + 'pretrained_window_size': pretrained_window_sizes[i], + 'downsample_cfg': dict(use_post_norm=True), + **stage_cfg + } + + stage = SwinBlockV2Sequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self): + super(SwinTransformerV2, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. 
+ return + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + stage.out_channels).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(SwinTransformerV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'absolute_pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.absolute_pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Resize the absolute_pos_embed shape from ' + f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs): + # delete relative_position_index since we always re-init it + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + 'Delete `relative_position_index` and `relative_coords_table` ' + 'since we always re-init these params according to the ' + '`window_size`, which might cause unwanted but unworried ' + 'warnings when loading checkpoint.') + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete relative_coords_table since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_coords_table' in k + ] + for k in relative_position_index_keys: + del state_dict[k] diff --git a/mmpretrain/models/backbones/t2t_vit.py b/mmpretrain/models/backbones/t2t_vit.py new file mode 100644 index 0000000..a57b95e --- /dev/null +++ b/mmpretrain/models/backbones/t2t_vit.py @@ -0,0 +1,447 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) +from .base_backbone import BaseBackbone + + +class T2TTransformerLayer(BaseModule): + """Transformer Layer for T2T_ViT. + + Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports + different ``input_dims`` and ``embed_dims``. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs + input_dims (int, optional): The input token dimension. + Defaults to None. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + + Notes: + In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e. + ``(embed_dims // num_heads) ** -0.5``. However, in the official + code, it uses ``(input_dims // num_heads) ** -0.5``, so here we + keep the same with the official implementation. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + input_dims=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=False, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg) + + self.v_shortcut = True if input_dims is not None else False + input_dims = input_dims or embed_dims + + self.ln1 = build_norm_layer(norm_cfg, input_dims) + + self.attn = MultiheadAttention( + input_dims=input_dims, + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + qk_scale=qk_scale or (input_dims // num_heads)**-0.5, + v_shortcut=self.v_shortcut) + + self.ln2 = build_norm_layer(norm_cfg, embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + def forward(self, x): + if self.v_shortcut: + x = self.attn(self.ln1(x)) + else: + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + return x + + +class T2TModule(BaseModule): + """Tokens-to-Token module. + + "Tokens-to-Token module" (T2T Module) can model the local structure + information of images and reduce the length of tokens progressively. 
+ + Args: + img_size (int): Input image size + in_channels (int): Number of input channels + embed_dims (int): Embedding dimension + token_dims (int): Tokens dimension in T2TModuleAttention. + use_performer (bool): If True, use Performer version self-attention to + adopt regular self-attention. Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Default: None. + + Notes: + Usually, ``token_dim`` is set as a small value (32 or 64) to reduce + MACs + """ + + def __init__( + self, + img_size=224, + in_channels=3, + embed_dims=384, + token_dims=64, + use_performer=False, + init_cfg=None, + ): + super(T2TModule, self).__init__(init_cfg) + + self.embed_dims = embed_dims + + self.soft_split0 = nn.Unfold( + kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) + self.soft_split1 = nn.Unfold( + kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.soft_split2 = nn.Unfold( + kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + + if not use_performer: + self.attention1 = T2TTransformerLayer( + input_dims=in_channels * 7 * 7, + embed_dims=token_dims, + num_heads=1, + feedforward_channels=token_dims) + + self.attention2 = T2TTransformerLayer( + input_dims=token_dims * 3 * 3, + embed_dims=token_dims, + num_heads=1, + feedforward_channels=token_dims) + + self.project = nn.Linear(token_dims * 3 * 3, embed_dims) + else: + raise NotImplementedError("Performer hasn't been implemented.") + + # there are 3 soft split, stride are 4,2,2 separately + out_side = img_size // (4 * 2 * 2) + self.init_out_size = [out_side, out_side] + self.num_patches = out_side**2 + + @staticmethod + def _get_unfold_size(unfold: nn.Unfold, input_size): + h, w = input_size + kernel_size = to_2tuple(unfold.kernel_size) + stride = to_2tuple(unfold.stride) + padding = to_2tuple(unfold.padding) + dilation = to_2tuple(unfold.dilation) + + h_out = (h + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (w + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + return (h_out, w_out) + + def forward(self, x): + # step0: soft split + hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:]) + x = self.soft_split0(x).transpose(1, 2) + + for step in [1, 2]: + # re-structurization/reconstruction + attn = getattr(self, f'attention{step}') + x = attn(x).transpose(1, 2) + B, C, _ = x.shape + x = x.reshape(B, C, hw_shape[0], hw_shape[1]) + + # soft split + soft_split = getattr(self, f'soft_split{step}') + hw_shape = self._get_unfold_size(soft_split, hw_shape) + x = soft_split(x).transpose(1, 2) + + # final tokens + x = self.project(x) + return x, hw_shape + + +def get_sinusoid_encoding(n_position, embed_dims): + """Generate sinusoid encoding table. + + Sinusoid encoding is a kind of relative position encoding method came from + `Attention Is All You Need`_. + + Args: + n_position (int): The length of the input token. + embed_dims (int): The position embedding dimension. + + Returns: + :obj:`torch.FloatTensor`: The sinusoid encoding table. 
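+
+    Examples:
+        A minimal shape sketch (illustrative only); the returned table has
+        a leading batch dimension of 1:
+
+        >>> from mmpretrain.models.backbones.t2t_vit import (
+        ...     get_sinusoid_encoding)
+        >>> table = get_sinusoid_encoding(n_position=10, embed_dims=32)
+        >>> table.shape  # expected: torch.Size([1, 10, 32])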
+ """ + + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (i // 2) / embed_dims) + for i in range(embed_dims) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos) for pos in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +@MODELS.register_module() +class T2T_ViT(BaseBackbone): + """Tokens-to-Token Vision Transformer (T2T-ViT) + + A PyTorch implementation of `Tokens-to-Token ViT: Training Vision + Transformers from Scratch on ImageNet `_ + + Args: + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + in_channels (int): Number of input channels. + embed_dims (int): Embedding dimension. + num_layers (int): Num of transformer layers in encoder. + Defaults to 14. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Dropout rate after position embedding. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. Defaults to + ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + t2t_cfg (dict): Extra config of Tokens-to-Token module. + Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
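+
+    Examples:
+        A minimal forward-shape sketch (illustrative only); the expected
+        shape assumes the default ``embed_dims=384`` and
+        ``out_type='cls_token'`` described above:
+
+        >>> import torch
+        >>> from mmpretrain.models.backbones.t2t_vit import T2T_ViT
+        >>> model = T2T_ViT(img_size=224)
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> outputs[0].shape  # expected: torch.Size([1, 384])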
+ """ + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + img_size=224, + in_channels=3, + embed_dims=384, + num_layers=14, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + final_norm=True, + out_type='cls_token', + with_cls_token=True, + interpolate_mode='bicubic', + t2t_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None): + super().__init__(init_cfg) + + # Token-to-Token Module + self.tokens_to_token = T2TModule( + img_size=img_size, + in_channels=in_channels, + embed_dims=embed_dims, + **t2t_cfg) + self.patch_resolution = self.tokens_to_token.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + sinusoid_table = get_sinusoid_encoding( + num_patches + self.num_extra_tokens, embed_dims) + self.register_buffer('pos_embed', sinusoid_table) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must be a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = num_layers + index + assert 0 <= out_indices[i] <= num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)] + + self.encoder = ModuleList() + for i in range(num_layers): + if isinstance(layer_cfgs, Sequence): + layer_cfg = layer_cfgs[i] + else: + layer_cfg = deepcopy(layer_cfgs) + layer_cfg = { + 'embed_dims': embed_dims, + 'num_heads': 6, + 'feedforward_channels': 3 * embed_dims, + 'drop_path_rate': dpr[i], + 'qkv_bias': False, + 'norm_cfg': norm_cfg, + **layer_cfg + } + + layer = T2TTransformerLayer(**layer_cfg) + self.encoder.append(layer) + + self.final_norm = final_norm + if final_norm: + self.norm = build_norm_layer(norm_cfg, embed_dims) + else: + self.norm = nn.Identity() + + def init_weights(self): + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress custom init if use pretrained model. 
+ return + + trunc_normal_(self.cls_token, std=.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.tokens_to_token.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.tokens_to_token(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.encoder): + x = layer(x) + + if i == len(self.encoder) - 1 and self.final_norm: + x = self.norm(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) diff --git a/mmpretrain/models/backbones/timm_backbone.py b/mmpretrain/models/backbones/timm_backbone.py new file mode 100644 index 0000000..51ecbdb --- /dev/null +++ b/mmpretrain/models/backbones/timm_backbone.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmengine.logging import MMLogger + +from mmpretrain.registry import MODELS +from mmpretrain.utils import require +from .base_backbone import BaseBackbone + + +def print_timm_feature_info(feature_info): + """Print feature_info of timm backbone to help development and debug. + + Args: + feature_info (list[dict] | timm.models.features.FeatureInfo | None): + feature_info of timm backbone. + """ + logger = MMLogger.get_current_instance() + if feature_info is None: + logger.warning('This backbone does not have feature_info') + elif isinstance(feature_info, list): + for feat_idx, each_info in enumerate(feature_info): + logger.info(f'backbone feature_info[{feat_idx}]: {each_info}') + else: + try: + logger.info(f'backbone out_indices: {feature_info.out_indices}') + logger.info(f'backbone out_channels: {feature_info.channels()}') + logger.info(f'backbone out_strides: {feature_info.reduction()}') + except AttributeError: + logger.warning('Unexpected format of backbone feature_info') + + +@MODELS.register_module() +class TIMMBackbone(BaseBackbone): + """Wrapper to use backbones from timm library. + + More details can be found in + `timm `_. + See especially the document for `feature extraction + `_. + + Args: + model_name (str): Name of timm model to instantiate. 
+ features_only (bool): Whether to extract feature pyramid (multi-scale + feature maps from the deepest layer at each stride). For Vision + Transformer models that do not support this argument, + set this False. Defaults to False. + pretrained (bool): Whether to load pretrained weights. + Defaults to False. + checkpoint_path (str): Path of checkpoint to load at the last of + ``timm.create_model``. Defaults to empty string, which means + not loading. + in_channels (int): Number of input image channels. Defaults to 3. + init_cfg (dict or list[dict], optional): Initialization config dict of + OpenMMLab projects. Defaults to None. + **kwargs: Other timm & model specific arguments. + """ + + @require('timm') + def __init__(self, + model_name, + features_only=False, + pretrained=False, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs): + import timm + + if not isinstance(pretrained, bool): + raise TypeError('pretrained must be bool, not str for model path') + if features_only and checkpoint_path: + warnings.warn( + 'Using both features_only and checkpoint_path will cause error' + ' in timm. See ' + 'https://github.com/rwightman/pytorch-image-models/issues/488') + + super(TIMMBackbone, self).__init__(init_cfg) + if 'norm_layer' in kwargs: + norm_class = MODELS.get(kwargs['norm_layer']) + + def build_norm(*args, **kwargs): + return norm_class(*args, **kwargs) + + kwargs['norm_layer'] = build_norm + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs) + + # reset classifier + if hasattr(self.timm_model, 'reset_classifier'): + self.timm_model.reset_classifier(0, '') + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + feature_info = getattr(self.timm_model, 'feature_info', None) + print_timm_feature_info(feature_info) + + def forward(self, x): + features = self.timm_model(x) + if isinstance(features, (list, tuple)): + features = tuple(features) + else: + features = (features, ) + return features diff --git a/mmpretrain/models/backbones/tinyvit.py b/mmpretrain/models/backbones/tinyvit.py new file mode 100644 index 0000000..5279832 --- /dev/null +++ b/mmpretrain/models/backbones/tinyvit.py @@ -0,0 +1,769 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from torch.nn import functional as F + +from mmpretrain.registry import MODELS +from ..utils import LeAttention +from .base_backbone import BaseBackbone + + +class ConvBN2d(Sequential): + """An implementation of Conv2d + BatchNorm2d with support of fusion. + + Modified from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int): The size of the convolution kernel. + Default: 1. + stride (int): The stride of the convolution. + Default: 1. + padding (int): The padding of the convolution. + Default: 0. + dilation (int): The dilation of the convolution. + Default: 1. + groups (int): The number of groups in the convolution. + Default: 1. + bn_weight_init (float): The initial value of the weight of + the nn.BatchNorm2d layer. Default: 1.0. 
+ init_cfg (dict): The initialization config of the module. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bn_weight_init=1.0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.add_module( + 'conv2d', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False)) + bn2d = nn.BatchNorm2d(num_features=out_channels) + # bn initialization + torch.nn.init.constant_(bn2d.weight, bn_weight_init) + torch.nn.init.constant_(bn2d.bias, 0) + self.add_module('bn2d', bn2d) + + @torch.no_grad() + def fuse(self): + conv2d, bn2d = self._modules.values() + w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5 + w = conv2d.weight * w[:, None, None, None] + b = bn2d.bias - bn2d.running_mean * bn2d.weight / \ + (bn2d.running_var + bn2d.eps)**0.5 + + m = nn.Conv2d( + in_channels=w.size(1) * self.c.groups, + out_channels=w.size(0), + kernel_size=w.shape[2:], + stride=self.conv2d.stride, + padding=self.conv2d.padding, + dilation=self.conv2d.dilation, + groups=self.conv2d.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class PatchEmbed(BaseModule): + """Patch Embedding for Vision Transformer. + + Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module use + Conv2d and BatchNorm2d to implement PatchEmbedding, and output shape is + (N, C, H, W). + + Args: + in_channels (int): The number of input channels. + embed_dim (int): The embedding dimension. + resolution (Tuple[int, int]): The resolution of the input feature. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + embed_dim, + resolution, + act_cfg=dict(type='GELU')): + super().__init__() + img_size: Tuple[int, int] = resolution + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_channels = in_channels + self.embed_dim = embed_dim + self.seq = nn.Sequential( + ConvBN2d( + in_channels, + embed_dim // 2, + kernel_size=3, + stride=2, + padding=1), + build_activation_layer(act_cfg), + ConvBN2d( + embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), + ) + + def forward(self, x): + return self.seq(x) + + +class PatchMerging(nn.Module): + """Patch Merging for TinyViT. + + Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Different from `mmpretrain.models.utils.PatchMerging`, this module use + Conv2d and BatchNorm2d to implement PatchMerging. + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + out_channels (int): The number of output channels. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). 
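+
+    Examples:
+        A minimal shape sketch (illustrative only); the input token length
+        must match the given ``resolution``:
+
+        >>> import torch
+        >>> from mmpretrain.models.backbones.tinyvit import PatchMerging
+        >>> merge = PatchMerging((56, 56), in_channels=64, out_channels=128)
+        >>> x = torch.rand(1, 56 * 56, 64)
+        >>> merge(x).shape  # expected: torch.Size([1, 784, 128])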
+ """ + + def __init__(self, + resolution, + in_channels, + out_channels, + act_cfg=dict(type='GELU')): + super().__init__() + + self.img_size = resolution + + self.act = build_activation_layer(act_cfg) + self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1) + self.conv2 = ConvBN2d( + out_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + groups=out_channels) + self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1) + self.out_resolution = (resolution[0] // 2, resolution[1] // 2) + + def forward(self, x): + if len(x.shape) == 3: + H, W = self.img_size + B = x.shape[0] + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + x = self.conv1(x) + x = self.act(x) + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + + x = x.flatten(2).transpose(1, 2) + return x + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + expand_ratio (int): The expand ratio of the hidden channels. + drop_rate (float): The drop rate of the block. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + out_channels, + expand_ratio, + drop_path, + act_cfg=dict(type='GELU')): + super().__init__() + self.in_channels = in_channels + hidden_channels = int(in_channels * expand_ratio) + + # linear + self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1) + self.act = build_activation_layer(act_cfg) + # depthwise conv + self.conv2 = ConvBN2d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_channels) + # linear + self.conv3 = ConvBN2d( + hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act(x) + + return x + + +class ConvStage(BaseModule): + """Convolution Stage for TinyViT. + + Adapted from + https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + depth (int): The number of blocks in the stage. + act_cfg (dict): The activation config of the module. + drop_path (float): The drop path of the block. + downsample (None | nn.Module): The downsample operation. + Default: None. + use_checkpoint (bool): Whether to use checkpointing to save memory. + out_channels (int): The number of output channels. + conv_expand_ratio (int): The expand ratio of the hidden channels. + Default: 4. + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. 
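+
+    Examples:
+        An illustrative shape sketch, pairing the stage with the
+        ``PatchMerging`` downsample defined above:
+
+        >>> import torch
+        >>> from mmpretrain.models.backbones.tinyvit import ConvStage
+        >>> from mmpretrain.models.backbones.tinyvit import PatchMerging
+        >>> stage = ConvStage(in_channels=64, resolution=(56, 56), depth=2,
+        ...                   act_cfg=dict(type='GELU'),
+        ...                   downsample=PatchMerging, out_channels=128)
+        >>> x = torch.rand(1, 64, 56, 56)
+        >>> stage(x).shape  # expected: torch.Size([1, 784, 128])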
+ """ + + def __init__(self, + in_channels, + resolution, + depth, + act_cfg, + drop_path=0., + downsample=None, + use_checkpoint=False, + out_channels=None, + conv_expand_ratio=4., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.use_checkpoint = use_checkpoint + # build blocks + self.blocks = ModuleList([ + MBConvBlock( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=conv_expand_ratio, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path) + for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + resolution=resolution, + in_channels=in_channels, + out_channels=out_channels, + act_cfg=act_cfg) + self.resolution = self.downsample.out_resolution + else: + self.downsample = None + self.resolution = resolution + + def forward(self, x): + for block in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(block, x) + else: + x = block(x) + + if self.downsample is not None: + x = self.downsample(x) + return x + + +class MLP(BaseModule): + """MLP module for TinyViT. + + Args: + in_channels (int): The number of input channels. + hidden_channels (int, optional): The number of hidden channels. + Default: None. + out_channels (int, optional): The number of output channels. + Default: None. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + drop (float): Probability of an element to be zeroed. + Default: 0. + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + hidden_channels=None, + out_channels=None, + act_cfg=dict(type='GELU'), + drop=0., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.norm = nn.LayerNorm(in_channels) + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.fc2 = nn.Linear(hidden_channels, out_channels) + self.act = build_activation_layer(act_cfg) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class TinyViTBlock(BaseModule): + """TinViT Block. + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + num_heads (int): The number of heads in the multi-head attention. + window_size (int): The size of the window. + Default: 7. + mlp_ratio (float): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop (float): Probability of an element to be zeroed. + Default: 0. + drop_path (float): The drop path of the block. + Default: 0. + local_conv_size (int): The size of the local convolution. + Default: 3. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + resolution, + num_heads, + window_size=7, + mlp_ratio=4., + drop=0., + drop_path=0., + local_conv_size=3, + act_cfg=dict(type='GELU')): + super().__init__() + self.in_channels = in_channels + self.img_size = resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + assert in_channels % num_heads == 0, \ + 'dim must be divisible by num_heads' + head_dim = in_channels // num_heads + + window_resolution = (window_size, window_size) + self.attn = LeAttention( + in_channels, + head_dim, + num_heads, + attn_ratio=1, + resolution=window_resolution) + + mlp_hidden_dim = int(in_channels * mlp_ratio) + self.mlp = MLP( + in_channels=in_channels, + hidden_channels=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop) + + self.local_conv = ConvBN2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=local_conv_size, + stride=1, + padding=local_conv_size // 2, + groups=in_channels) + + def forward(self, x): + H, W = self.img_size + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - + H % self.window_size) % self.window_size + pad_r = (self.window_size - + W % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, + C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + +class BasicStage(BaseModule): + """Basic Stage for TinyViT. + + Args: + in_channels (int): The number of input channels. + resolution (Tuple[int, int]): The resolution of the input feature. + depth (int): The number of blocks in the stage. + num_heads (int): The number of heads in the multi-head attention. + window_size (int): The size of the window. + mlp_ratio (float): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop (float): Probability of an element to be zeroed. + Default: 0. + drop_path (float): The drop path of the block. + Default: 0. + downsample (None | nn.Module): The downsample operation. + Default: None. + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. 
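+
+    Examples:
+        A rough shape sketch (illustrative only); it assumes the attention
+        blocks keep the channel dimension unchanged:
+
+        >>> import torch
+        >>> from mmpretrain.models.backbones.tinyvit import BasicStage
+        >>> stage = BasicStage(in_channels=128, resolution=(14, 14),
+        ...                    depth=2, num_heads=4, window_size=7)
+        >>> x = torch.rand(1, 14 * 14, 128)
+        >>> stage(x).shape  # expected: torch.Size([1, 196, 128])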
+ """ + + def __init__(self, + in_channels, + resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + drop=0., + drop_path=0., + downsample=None, + use_checkpoint=False, + local_conv_size=3, + out_channels=None, + act_cfg=dict(type='GELU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.use_checkpoint = use_checkpoint + # build blocks + self.blocks = ModuleList([ + TinyViTBlock( + in_channels=in_channels, + resolution=resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + local_conv_size=local_conv_size, + act_cfg=act_cfg, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path) + for i in range(depth) + ]) + + # build patch merging layer + if downsample is not None: + self.downsample = downsample( + resolution=resolution, + in_channels=in_channels, + out_channels=out_channels, + act_cfg=act_cfg) + self.resolution = self.downsample.out_resolution + else: + self.downsample = None + self.resolution = resolution + + def forward(self, x): + for block in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(block, x) + else: + x = block(x) + + if self.downsample is not None: + x = self.downsample(x) + return x + + +@MODELS.register_module() +class TinyViT(BaseBackbone): + """TinyViT. + A PyTorch implementation of : `TinyViT: Fast Pretraining Distillation + for Small Vision Transformers`_ + + Inspiration from + https://github.com/microsoft/Cream/blob/main/TinyViT + + Args: + arch (str | dict): The architecture of TinyViT. + Default: '5m'. + img_size (tuple | int): The resolution of the input image. + Default: (224, 224) + window_size (list): The size of the window. + Default: [7, 7, 14, 7] + in_channels (int): The number of input channels. + Default: 3. + depths (list[int]): The depth of each stage. + Default: [2, 2, 6, 2]. + mlp_ratio (list[int]): The ratio of mlp hidden dim to embedding dim. + Default: 4. + drop_rate (float): Probability of an element to be zeroed. + Default: 0. + drop_path_rate (float): The drop path of the block. + Default: 0.1. + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False. + mbconv_expand_ratio (int): The expand ratio of the mbconv. + Default: 4.0 + local_conv_size (int): The size of the local conv. + Default: 3. + layer_lr_decay (float): The layer lr decay. + Default: 1.0 + out_indices (int | list[int]): Output from which stages. + Default: -1 + frozen_stages (int | list[int]): Stages to be frozen (all param fixed). + Default: -0 + gap_before_final_nrom (bool): Whether to add a gap before the final + norm. Default: True. + act_cfg (dict): The activation config of the module. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict | list[dict], optional): Initialization config dict. + Default: None. 
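+
+    Examples:
+        A minimal forward-shape sketch (illustrative only); the expected
+        shape assumes the default ``arch='5m'`` and
+        ``gap_before_final_norm=True`` settings listed above:
+
+        >>> import torch
+        >>> from mmpretrain.models.backbones.tinyvit import TinyViT
+        >>> model = TinyViT(arch='5m')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> outputs[0].shape  # expected: torch.Size([1, 320])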
+ """ + arch_settings = { + '5m': { + 'channels': [64, 128, 160, 320], + 'num_heads': [2, 4, 5, 10], + 'depths': [2, 2, 6, 2], + }, + '11m': { + 'channels': [64, 128, 256, 448], + 'num_heads': [2, 4, 8, 14], + 'depths': [2, 2, 6, 2], + }, + '21m': { + 'channels': [96, 192, 384, 576], + 'num_heads': [3, 6, 12, 18], + 'depths': [2, 2, 6, 2], + }, + } + + def __init__(self, + arch='5m', + img_size=(224, 224), + window_size=[7, 7, 14, 7], + in_channels=3, + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + out_indices=-1, + frozen_stages=0, + gap_before_final_norm=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavaiable arch, please choose from ' \ + f'({set(self.arch_settings)} or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'channels' in arch and 'num_heads' in arch and \ + 'depths' in arch, 'The arch dict must have' \ + f'"channels", "num_heads", "window_sizes" ' \ + f'keys, but got {arch.keys()}' + + self.channels = arch['channels'] + self.num_heads = arch['num_heads'] + self.widow_sizes = window_size + self.img_size = img_size + self.depths = arch['depths'] + + self.num_stages = len(self.channels) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_final_norm = gap_before_final_norm + self.layer_lr_decay = layer_lr_decay + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dim=self.channels[0], + resolution=self.img_size, + act_cfg=dict(type='GELU')) + patches_resolution = self.patch_embed.patches_resolution + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + + # build stages + self.stages = ModuleList() + for i in range(self.num_stages): + depth = self.depths[i] + channel = self.channels[i] + curr_resolution = (patches_resolution[0] // (2**i), + patches_resolution[1] // (2**i)) + drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])] + downsample = PatchMerging if (i < self.num_stages - 1) else None + out_channels = self.channels[min(i + 1, self.num_stages - 1)] + if i >= 1: + stage = BasicStage( + in_channels=channel, + resolution=curr_resolution, + depth=depth, + num_heads=self.num_heads[i], + window_size=self.widow_sizes[i], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=drop_path, + downsample=downsample, + use_checkpoint=use_checkpoint, + local_conv_size=local_conv_size, + out_channels=out_channels, + act_cfg=act_cfg) + else: + stage = ConvStage( + in_channels=channel, + resolution=curr_resolution, + depth=depth, + act_cfg=act_cfg, + drop_path=drop_path, + downsample=downsample, + use_checkpoint=use_checkpoint, + out_channels=out_channels, + conv_expand_ratio=mbconv_expand_ratio) + self.stages.append(stage) + + # add output norm + if i in self.out_indices: + norm_layer = build_norm_layer(norm_cfg, out_channels)[1] + self.add_module(f'norm{i}', norm_layer) + + def set_layer_lr_decay(self, layer_lr_decay): + # TODO: add 
layer_lr_decay + pass + + def forward(self, x): + outs = [] + x = self.patch_embed(x) + + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + if self.gap_before_final_norm: + gap = x.mean(1) + outs.append(norm_layer(gap)) + else: + out = norm_layer(x) + # convert the (B,L,C) format into (B,C,H,W) format + # which would be better for the downstream tasks. + B, L, C = out.shape + out = out.view(B, *stage.resolution, C) + outs.append(out.permute(0, 3, 1, 2)) + + return tuple(outs) + + def _freeze_stages(self): + for i in range(self.frozen_stages): + stage = self.stages[i] + stage.eval() + for param in stage.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(TinyViT, self).train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/backbones/tnt.py b/mmpretrain/models/backbones/tnt.py new file mode 100644 index 0000000..e1b241c --- /dev/null +++ b/mmpretrain/models/backbones/tnt.py @@ -0,0 +1,368 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import to_2tuple +from .base_backbone import BaseBackbone + + +class TransformerBlock(BaseModule): + """Implement a transformer block in TnTLayer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer. + Default: 4 + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0. + drop_path_rate (float): stochastic depth rate. Default 0. + num_fcs (int): The number of fully-connected layers for FFNs. Default 2 + qkv_bias (bool): Enable bias for qkv if True. Default False + act_cfg (dict): The activation config for FFNs. Defaults to GELU. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) or (n, batch, embed_dim). + (batch, n, embed_dim) is common case in CV. Defaults to False + init_cfg (dict, optional): Initialization config dict. 
Defaults to None + """ + + def __init__(self, + embed_dims, + num_heads, + ffn_ratio=4, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + init_cfg=None): + super(TransformerBlock, self).__init__(init_cfg=init_cfg) + + self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + batch_first=batch_first) + + self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=embed_dims * ffn_ratio, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + if not qkv_bias: + self.attn.attn.in_proj_bias = None + + def forward(self, x): + x = self.attn(self.norm_attn(x), identity=x) + x = self.ffn(self.norm_ffn(x), identity=x) + return x + + +class TnTLayer(BaseModule): + """Implement one encoder layer in Transformer in Transformer. + + Args: + num_pixel (int): The pixel number in target patch transformed with + a linear projection in inner transformer + embed_dims_inner (int): Feature dimension in inner transformer block + embed_dims_outer (int): Feature dimension in outer transformer block + num_heads_inner (int): Parallel attention heads in inner transformer. + num_heads_outer (int): Parallel attention heads in outer transformer. + inner_block_cfg (dict): Extra config of inner transformer block. + Defaults to empty dict. + outer_block_cfg (dict): Extra config of outer transformer block. + Defaults to empty dict. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + init_cfg (dict, optional): Initialization config dict. Defaults to None + """ + + def __init__(self, + num_pixel, + embed_dims_inner, + embed_dims_outer, + num_heads_inner, + num_heads_outer, + inner_block_cfg=dict(), + outer_block_cfg=dict(), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(TnTLayer, self).__init__(init_cfg=init_cfg) + + self.inner_block = TransformerBlock( + embed_dims=embed_dims_inner, + num_heads=num_heads_inner, + **inner_block_cfg) + + self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1] + self.projection = nn.Linear( + embed_dims_inner * num_pixel, embed_dims_outer, bias=True) + + self.outer_block = TransformerBlock( + embed_dims=embed_dims_outer, + num_heads=num_heads_outer, + **outer_block_cfg) + + def forward(self, pixel_embed, patch_embed): + pixel_embed = self.inner_block(pixel_embed) + + B, N, C = patch_embed.size() + patch_embed[:, 1:] = patch_embed[:, 1:] + self.projection( + self.norm_proj(pixel_embed).reshape(B, N - 1, -1)) + patch_embed = self.outer_block(patch_embed) + + return pixel_embed, patch_embed + + +class PixelEmbed(BaseModule): + """Image to Pixel Embedding. + + Args: + img_size (int | tuple): The size of input image + patch_size (int): The size of one patch + in_channels (int): The num of input channels + embed_dims_inner (int): The num of channels of the target patch + transformed with a linear projection in inner transformer + stride (int): The stride of the conv2d layer. We use a conv2d layer + and a unfold layer to implement image to pixel embedding. 
+ init_cfg (dict, optional): Initialization config dict + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims_inner=48, + stride=4, + init_cfg=None): + super(PixelEmbed, self).__init__(init_cfg=init_cfg) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + # patches_resolution property necessary for resizing + # positional embedding + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + num_patches = patches_resolution[0] * patches_resolution[1] + + self.img_size = img_size + self.num_patches = num_patches + self.embed_dims_inner = embed_dims_inner + + new_patch_size = [math.ceil(ps / stride) for ps in patch_size] + self.new_patch_size = new_patch_size + + self.proj = nn.Conv2d( + in_channels, + self.embed_dims_inner, + kernel_size=7, + padding=3, + stride=stride) + self.unfold = nn.Unfold( + kernel_size=new_patch_size, stride=new_patch_size) + + def forward(self, x, pixel_pos): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model " \ + f'({self.img_size[0]}*{self.img_size[1]}).' + x = self.proj(x) + x = self.unfold(x) + x = x.transpose(1, + 2).reshape(B * self.num_patches, self.embed_dims_inner, + self.new_patch_size[0], + self.new_patch_size[1]) + x = x + pixel_pos + x = x.reshape(B * self.num_patches, self.embed_dims_inner, + -1).transpose(1, 2) + return x + + +@MODELS.register_module() +class TNT(BaseBackbone): + """Transformer in Transformer. + + A PyTorch implement of: `Transformer in Transformer + `_ + + Inspiration from + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size. Defaults to 224 + patch_size (int | tuple): The patch size. Deault to 16 + in_channels (int): Number of input channels. Defaults to 3 + ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer. + Default: 4 + qkv_bias (bool): Enable bias for qkv if True. Default False + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0. + drop_path_rate (float): stochastic depth rate. Default 0. + act_cfg (dict): The activation config for FFNs. Defaults to GELU. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization + first_stride (int): The stride of the conv2d layer. We use a conv2d + layer and a unfold layer to implement image to pixel embedding. + num_fcs (int): The number of fully-connected layers for FFNs. Default 2 + init_cfg (dict, optional): Initialization config dict + """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims_outer': 384, + 'embed_dims_inner': 24, + 'num_layers': 12, + 'num_heads_outer': 6, + 'num_heads_inner': 4 + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims_outer': 640, + 'embed_dims_inner': 40, + 'num_layers': 12, + 'num_heads_outer': 10, + 'num_heads_inner': 4 + }) + } + + def __init__(self, + arch='b', + img_size=224, + patch_size=16, + in_channels=3, + ffn_ratio=4, + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + first_stride=4, + num_fcs=2, + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
+ ]): + super(TNT, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims_outer', 'embed_dims_inner', 'num_layers', + 'num_heads_inner', 'num_heads_outer' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims_inner = self.arch_settings['embed_dims_inner'] + self.embed_dims_outer = self.arch_settings['embed_dims_outer'] + # embed_dims for consistency with other models + self.embed_dims = self.embed_dims_outer + self.num_layers = self.arch_settings['num_layers'] + self.num_heads_inner = self.arch_settings['num_heads_inner'] + self.num_heads_outer = self.arch_settings['num_heads_outer'] + + self.pixel_embed = PixelEmbed( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims_inner=self.embed_dims_inner, + stride=first_stride) + num_patches = self.pixel_embed.num_patches + self.num_patches = num_patches + new_patch_size = self.pixel_embed.new_patch_size + num_pixel = new_patch_size[0] * new_patch_size[1] + + self.norm1_proj = build_norm_layer(norm_cfg, num_pixel * + self.embed_dims_inner)[1] + self.projection = nn.Linear(num_pixel * self.embed_dims_inner, + self.embed_dims_outer) + self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1] + + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims_outer)) + self.patch_pos = nn.Parameter( + torch.zeros(1, num_patches + 1, self.embed_dims_outer)) + self.pixel_pos = nn.Parameter( + torch.zeros(1, self.embed_dims_inner, new_patch_size[0], + new_patch_size[1])) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, self.num_layers) + ] # stochastic depth decay rule + self.layers = ModuleList() + for i in range(self.num_layers): + block_cfg = dict( + ffn_ratio=ffn_ratio, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + batch_first=True) + self.layers.append( + TnTLayer( + num_pixel=num_pixel, + embed_dims_inner=self.embed_dims_inner, + embed_dims_outer=self.embed_dims_outer, + num_heads_inner=self.num_heads_inner, + num_heads_outer=self.num_heads_outer, + inner_block_cfg=block_cfg, + outer_block_cfg=block_cfg, + norm_cfg=norm_cfg)) + + self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1] + + trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.patch_pos, std=.02) + trunc_normal_(self.pixel_pos, std=.02) + + def forward(self, x): + B = x.shape[0] + pixel_embed = self.pixel_embed(x, self.pixel_pos) + + patch_embed = self.norm2_proj( + self.projection( + self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) + patch_embed = torch.cat( + (self.cls_token.expand(B, -1, -1), patch_embed), dim=1) + patch_embed = patch_embed + self.patch_pos + patch_embed = self.drop_after_pos(patch_embed) + + for layer in self.layers: + pixel_embed, patch_embed = layer(pixel_embed, patch_embed) + + patch_embed = self.norm(patch_embed) + return (patch_embed[:, 0], ) diff --git a/mmpretrain/models/backbones/twins.py b/mmpretrain/models/backbones/twins.py new file mode 100644 index 0000000..be55c02 --- /dev/null +++ b/mmpretrain/models/backbones/twins.py @@ -0,0 +1,721 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from ..utils import ConditionalPositionEncoding, MultiheadAttention + + +class GlobalSubsampledAttention(MultiheadAttention): + """Global Sub-sampled Attention (GSA) module. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + sr_ratio (float): The ratio of spatial reduction in attention modules. + Defaults to 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + norm_cfg=dict(type='LN'), + qkv_bias=True, + sr_ratio=1, + **kwargs): + super(GlobalSubsampledAttention, + self).__init__(embed_dims, num_heads, **kwargs) + + self.qkv_bias = qkv_bias + self.q = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias) + self.kv = nn.Linear(self.input_dims, embed_dims * 2, bias=qkv_bias) + + # remove self.qkv, here split into self.q, self.kv + delattr(self, 'qkv') + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + # use a conv as the spatial-reduction operation, the kernel_size + # and stride in conv are equal to the sr_ratio. + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, x, hw_shape): + B, N, C = x.shape + H, W = hw_shape + assert H * W == N, 'The product of h and w of hw_shape must be N, ' \ + 'which is the 2nd dim number of the input Tensor x.' + + q = self.q(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x = x.permute(0, 2, 1).reshape(B, C, *hw_shape) # BNC_2_BCHW + x = self.sr(x) + x = x.reshape(B, C, -1).permute(0, 2, 1) # BCHW_2_BNC + x = self.norm(x) + + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn_drop = self.attn_drop if self.training else 0. 
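+        # Attention between the full-resolution queries and the (possibly
+        # spatially reduced) keys/values; attention dropout is applied only
+        # in training mode, hence the 0. above for eval.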
+ x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + + x = self.proj(x) + x = self.out_drop(self.proj_drop(x)) + + if self.v_shortcut: + x = v.squeeze(1) + x + return x + + +class GSAEncoderLayer(BaseModule): + """Implements one encoder layer with GlobalSubsampledAttention(GSA). + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (float): The ratio of spatial reduction in attention modules. + Defaults to 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1., + init_cfg=None): + super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = GlobalSubsampledAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class LocallyGroupedSelfAttention(BaseModule): + """Locally-grouped Self Attention (LSA) module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + window_size(int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + window_size=1, + init_cfg=None): + super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg) + + assert embed_dims % num_heads == 0, \ + f'dim {embed_dims} should be divided by num_heads {num_heads}' + + self.embed_dims = embed_dims + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + self.window_size = window_size + + def forward(self, x, hw_shape): + B, N, C = x.shape + H, W = hw_shape + x = x.view(B, H, W, C) + + # pad feature maps to multiples of Local-groups + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + # calculate attention mask for LSA + Hp, Wp = x.shape[1:-1] + _h, _w = Hp // self.window_size, Wp // self.window_size + mask = torch.zeros((1, Hp, Wp), device=x.device) + mask[:, -pad_b:, :].fill_(1) + mask[:, :, -pad_r:].fill_(1) + + # [B, _h, _w, window_size, window_size, C] + x = x.reshape(B, _h, self.window_size, _w, self.window_size, + C).transpose(2, 3) + mask = mask.reshape(1, _h, self.window_size, _w, + self.window_size).transpose(2, 3).reshape( + 1, _h * _w, + self.window_size * self.window_size) + # [1, _h*_w, window_size*window_size, window_size*window_size] + attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-1000.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # [3, B, _w*_h, nhead, window_size*window_size, dim] + qkv = self.qkv(x).reshape(B, _h * _w, + self.window_size * self.window_size, 3, + self.num_heads, C // self.num_heads).permute( + 3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + # [B, _h*_w, n_head, window_size*window_size, window_size*window_size] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + attn_mask.unsqueeze(2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.window_size, + self.window_size, C) + x = attn.transpose(2, 3).reshape(B, _h * self.window_size, + _w * self.window_size, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LSAEncoderLayer(BaseModule): + """Implements one encoder layer with LocallyGroupedSelfAttention(LSA). + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. 
+ Default: dict(type='LN'). + window_size (int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=1, + init_cfg=None): + + super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, + qkv_bias, qk_scale, + attn_drop_rate, drop_rate, + window_size) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +@MODELS.register_module() +class PCPVT(BaseModule): + """The backbone of Twins-PCPVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + arch (dict, str): PCPVT architecture, a str value in arch zoo or a + detailed configuration dict with 7 keys, and the length of all the + values in dict should be the same: + + - depths (List[int]): The number of encoder layers in each stage. + - embed_dims (List[int]): Embedding dimension in each stage. + - patch_sizes (List[int]): The patch sizes in each stage. + - num_heads (List[int]): Numbers of attention head in each stage. + - strides (List[int]): The strides in each stage. + - mlp_ratios (List[int]): The ratios of mlp in each stage. + - sr_ratios (List[int]): The ratios of GSA-encoder layers in each + stage. + + in_channels (int): Number of input channels. Defaults to 3. + out_indices (tuple[int]): Output from which stages. + Defaults to ``(3, )``. + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0 + drop_path_rate (float): Stochastic depth rate. Defaults to 0.0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + norm_after_stage(bool, List[bool]): Add extra norm after each stage. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ + Examples: + >>> from mmpretrain.models import PCPVT + >>> import torch + >>> pcpvt_cfg = {'arch': "small", + >>> 'norm_after_stage': [False, False, False, True]} + >>> model = PCPVT(**pcpvt_cfg) + >>> x = torch.rand(1, 3, 224, 224) + >>> outputs = model(x) + >>> print(outputs[-1].shape) + torch.Size([1, 512, 7, 7]) + >>> pcpvt_cfg['norm_after_stage'] = [True, True, True, True] + >>> pcpvt_cfg['out_indices'] = (0, 1, 2, 3) + >>> model = PCPVT(**pcpvt_cfg) + >>> outputs = model(x) + >>> for feat in outputs: + >>> print(feat.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 320, 14, 14]) + torch.Size([1, 512, 7, 7]) + """ + arch_zoo = { + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 4, 6, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 4, 18, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 8, 27, 3], + 'num_heads': [1, 2, 5, 8], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [8, 8, 4, 4], + 'sr_ratios': [8, 4, 2, 1]}), + } # yapf: disable + + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides', + 'mlp_ratios', 'sr_ratios' + } + + def __init__(self, + arch, + in_channels=3, + out_indices=(3, ), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + norm_after_stage=False, + init_cfg=None): + super(PCPVT, self).__init__(init_cfg=init_cfg) + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + assert isinstance(arch, dict) and ( + set(arch) == self.essential_keys + ), f'Custom arch needs a dict with keys {self.essential_keys}.' 
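+            # a user-supplied dict is taken as the stage-wise settings once
+            # its keys match the seven entries in ``essential_keys``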
+ self.arch_settings = arch + + self.depths = self.arch_settings['depths'] + self.embed_dims = self.arch_settings['embed_dims'] + self.patch_sizes = self.arch_settings['patch_sizes'] + self.strides = self.arch_settings['strides'] + self.mlp_ratios = self.arch_settings['mlp_ratios'] + self.num_heads = self.arch_settings['num_heads'] + self.sr_ratios = self.arch_settings['sr_ratios'] + + self.num_extra_tokens = 0 # there is no cls-token in Twins + self.num_stage = len(self.depths) + for key, value in self.arch_settings.items(): + assert isinstance(value, list) and len(value) == self.num_stage, ( + 'Length of setting item in arch dict must be type of list and' + ' have the same length.') + + # patch_embeds + self.patch_embeds = ModuleList() + self.position_encoding_drops = ModuleList() + self.stages = ModuleList() + + for i in range(self.num_stage): + # use in_channels of the model in the first stage + if i == 0: + stage_in_channels = in_channels + else: + stage_in_channels = self.embed_dims[i - 1] + + self.patch_embeds.append( + PatchEmbed( + in_channels=stage_in_channels, + embed_dims=self.embed_dims[i], + conv_type='Conv2d', + kernel_size=self.patch_sizes[i], + stride=self.strides[i], + padding='corner', + norm_cfg=dict(type='LN'))) + + self.position_encoding_drops.append(nn.Dropout(p=drop_rate)) + + # PEGs + self.position_encodings = ModuleList([ + ConditionalPositionEncoding(embed_dim, embed_dim) + for embed_dim in self.embed_dims + ]) + + # stochastic depth + total_depth = sum(self.depths) + self.dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + cur = 0 + + for k in range(len(self.depths)): + _block = ModuleList([ + GSAEncoderLayer( + embed_dims=self.embed_dims[k], + num_heads=self.num_heads[k], + feedforward_channels=self.mlp_ratios[k] * + self.embed_dims[k], + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=self.dpr[cur + i], + num_fcs=2, + qkv_bias=qkv_bias, + act_cfg=dict(type='GELU'), + norm_cfg=norm_cfg, + sr_ratio=self.sr_ratios[k]) for i in range(self.depths[k]) + ]) + self.stages.append(_block) + cur += self.depths[k] + + self.out_indices = out_indices + + assert isinstance(norm_after_stage, (bool, list)) + if isinstance(norm_after_stage, bool): + self.norm_after_stage = [norm_after_stage] * self.num_stage + else: + self.norm_after_stage = norm_after_stage + assert len(self.norm_after_stage) == self.num_stage, \ + (f'Number of norm_after_stage({len(self.norm_after_stage)}) should' + f' be equal to the number of stages({self.num_stage}).') + + for i, has_norm in enumerate(self.norm_after_stage): + assert isinstance(has_norm, bool), 'norm_after_stage should be ' \ + 'bool or List[bool].' + if has_norm and norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, self.embed_dims[i])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm_after_stage{i}', norm_layer) + + def init_weights(self): + if self.init_cfg is not None: + super(PCPVT, self).init_weights() + else: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) 
+ elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + + def forward(self, x): + outputs = list() + + b = x.shape[0] + + for i in range(self.num_stage): + x, hw_shape = self.patch_embeds[i](x) + h, w = hw_shape + x = self.position_encoding_drops[i](x) + for j, blk in enumerate(self.stages[i]): + x = blk(x, hw_shape) + if j == 0: + x = self.position_encodings[i](x, hw_shape) + + norm_layer = getattr(self, f'norm_after_stage{i}') + x = norm_layer(x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + + if i in self.out_indices: + outputs.append(x) + + return tuple(outputs) + + +@MODELS.register_module() +class SVT(PCPVT): + """The backbone of Twins-SVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + arch (dict, str): SVT architecture, a str value in arch zoo or a + detailed configuration dict with 8 keys, and the length of all the + values in dict should be the same: + + - depths (List[int]): The number of encoder layers in each stage. + - embed_dims (List[int]): Embedding dimension in each stage. + - patch_sizes (List[int]): The patch sizes in each stage. + - num_heads (List[int]): Numbers of attention head in each stage. + - strides (List[int]): The strides in each stage. + - mlp_ratios (List[int]): The ratios of mlp in each stage. + - sr_ratios (List[int]): The ratios of GSA-encoder layers in each + stage. + - windiow_sizes (List[int]): The window sizes in LSA-encoder layers + in each stage. + + in_channels (int): Number of input channels. Defaults to 3. + out_indices (tuple[int]): Output from which stages. + Defaults to (3, ). + qkv_bias (bool): Enable bias for qkv if True. Defaults to False. + drop_rate (float): Dropout rate. Defaults to 0. + attn_drop_rate (float): Dropout ratio of attention weight. + Defaults to 0.0 + drop_path_rate (float): Stochastic depth rate. Defaults to 0.2. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + norm_after_stage(bool, List[bool]): Add extra norm after each stage. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ + Examples: + >>> from mmpretrain.models import SVT + >>> import torch + >>> svt_cfg = {'arch': "small", + >>> 'norm_after_stage': [False, False, False, True]} + >>> model = SVT(**svt_cfg) + >>> x = torch.rand(1, 3, 224, 224) + >>> outputs = model(x) + >>> print(outputs[-1].shape) + torch.Size([1, 512, 7, 7]) + >>> svt_cfg["out_indices"] = (0, 1, 2, 3) + >>> svt_cfg["norm_after_stage"] = [True, True, True, True] + >>> model = SVT(**svt_cfg) + >>> output = model(x) + >>> for feat in output: + >>> print(feat.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 320, 14, 14]) + torch.Size([1, 512, 7, 7]) + """ + arch_zoo = { + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 256, 512], + 'depths': [2, 2, 10, 4], + 'num_heads': [2, 4, 8, 16], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [96, 192, 384, 768], + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [128, 256, 512, 1024], + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32], + 'patch_sizes': [4, 2, 2, 2], + 'strides': [4, 2, 2, 2], + 'mlp_ratios': [4, 4, 4, 4], + 'sr_ratios': [8, 4, 2, 1], + 'window_sizes': [7, 7, 7, 7]}), + } # yapf: disable + + essential_keys = { + 'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides', + 'mlp_ratios', 'sr_ratios', 'window_sizes' + } + + def __init__(self, + arch, + in_channels=3, + out_indices=(3, ), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.0, + norm_cfg=dict(type='LN'), + norm_after_stage=False, + init_cfg=None): + super(SVT, self).__init__(arch, in_channels, out_indices, qkv_bias, + drop_rate, attn_drop_rate, drop_path_rate, + norm_cfg, norm_after_stage, init_cfg) + + self.window_sizes = self.arch_settings['window_sizes'] + + for k in range(self.num_stage): + for i in range(self.depths[k]): + # in even-numbered layers of each stage, replace GSA with LSA + if i % 2 == 0: + ffn_channels = self.mlp_ratios[k] * self.embed_dims[k] + self.stages[k][i] = \ + LSAEncoderLayer( + embed_dims=self.embed_dims[k], + num_heads=self.num_heads[k], + feedforward_channels=ffn_channels, + drop_rate=drop_rate, + norm_cfg=norm_cfg, + attn_drop_rate=attn_drop_rate, + drop_path_rate=self.dpr[sum(self.depths[:k])+i], + qkv_bias=qkv_bias, + window_size=self.window_sizes[k]) diff --git a/mmpretrain/models/backbones/van.py b/mmpretrain/models/backbones/van.py new file mode 100644 index 0000000..c34dc33 --- /dev/null +++ b/mmpretrain/models/backbones/van.py @@ -0,0 +1,434 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +class MixFFN(BaseModule): + """An implementation of MixFFN of VAN. Refer to + mmdetection/mmdet/models/backbones/pvt.py. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. 
Introduce 3X3 Depth-wise Conv to encode positional information.
+
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`.
+        feedforward_channels (int): The hidden dimension of FFNs.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default 0.0.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 feedforward_channels,
+                 act_cfg=dict(type='GELU'),
+                 ffn_drop=0.,
+                 init_cfg=None):
+        super(MixFFN, self).__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.act_cfg = act_cfg
+
+        self.fc1 = Conv2d(
+            in_channels=embed_dims,
+            out_channels=feedforward_channels,
+            kernel_size=1)
+        self.dwconv = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=feedforward_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+            groups=feedforward_channels)
+        self.act = build_activation_layer(act_cfg)
+        self.fc2 = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=embed_dims,
+            kernel_size=1)
+        self.drop = nn.Dropout(ffn_drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.dwconv(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class LKA(BaseModule):
+    """Large Kernel Attention (LKA) of VAN.
+
+    .. code:: text
+
+            DW_conv (depth-wise convolution)
+                            |
+                            |
+        DW_D_conv (depth-wise dilation convolution)
+                            |
+                            |
+        Transition Convolution (1×1 convolution)
+
+    Args:
+        embed_dims (int): Number of input channels.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, embed_dims, init_cfg=None):
+        super(LKA, self).__init__(init_cfg=init_cfg)
+
+        # a spatial local convolution (depth-wise convolution)
+        self.DW_conv = Conv2d(
+            in_channels=embed_dims,
+            out_channels=embed_dims,
+            kernel_size=5,
+            padding=2,
+            groups=embed_dims)
+
+        # a spatial long-range convolution (depth-wise dilation convolution)
+        self.DW_D_conv = Conv2d(
+            in_channels=embed_dims,
+            out_channels=embed_dims,
+            kernel_size=7,
+            stride=1,
+            padding=9,
+            groups=embed_dims,
+            dilation=3)
+
+        self.conv1 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+    def forward(self, x):
+        u = x.clone()
+        attn = self.DW_conv(x)
+        attn = self.DW_D_conv(attn)
+        attn = self.conv1(attn)
+
+        return u * attn
+
+
+class SpatialAttention(BaseModule):
+    """Basic attention module in VANBlock.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
+        super(SpatialAttention, self).__init__(init_cfg=init_cfg)
+
+        self.proj_1 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+        self.activation = build_activation_layer(act_cfg)
+        self.spatial_gating_unit = LKA(embed_dims)
+        self.proj_2 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+    def forward(self, x):
+        shortcut = x.clone()
+        x = self.proj_1(x)
+        x = self.activation(x)
+        x = self.spatial_gating_unit(x)
+        x = self.proj_2(x)
+        x = x + shortcut
+        return x
+
+
+class VANBlock(BaseModule):
+    """A block of VAN.
+
+    Args:
+        embed_dims (int): Number of input channels.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='GELU'). + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-2. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + ffn_ratio=4., + drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN', eps=1e-5), + layer_scale_init_value=1e-2, + init_cfg=None): + super(VANBlock, self).__init__(init_cfg=init_cfg) + self.out_channels = embed_dims + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + mlp_hidden_dim = int(embed_dims * ffn_ratio) + self.mlp = MixFFN( + embed_dims=embed_dims, + feedforward_channels=mlp_hidden_dim, + act_cfg=act_cfg, + ffn_drop=drop_rate) + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) if layer_scale_init_value > 0 else None + + def forward(self, x): + identity = x + x = self.norm1(x) + x = self.attn(x) + if self.layer_scale_1 is not None: + x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x + x = identity + self.drop_path(x) + + identity = x + x = self.norm2(x) + x = self.mlp(x) + if self.layer_scale_2 is not None: + x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x + x = identity + self.drop_path(x) + + return x + + +class VANPatchEmbed(PatchEmbed): + """Image to Patch Embedding of VAN. + + The differences between VANPatchEmbed & PatchEmbed: + 1. Use BN. + 2. Do not use 'flatten' and 'transpose'. + """ + + def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs): + super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs) + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + Returns: + tuple: Contains merged results and its spatial shape. + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adaptive_padding: + x = self.adaptive_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +@MODELS.register_module() +class VAN(BaseBackbone): + """Visual Attention Network. + + A PyTorch implement of : `Visual Attention Network + `_ + + Inspiration from + https://github.com/Visual-Attention-Network/VAN-Classification + + Args: + arch (str | dict): Visual Attention Network architecture. + If use string, choose from 'tiny', 'small', 'base' and 'large'. + If use dict, it should have below keys: + + - **embed_dims** (List[int]): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **ffn_ratios** (List[int]): The number of expansion ratio of + feedforward network hidden layer channels. + + Defaults to 'tiny'. + patch_sizes (List[int | tuple]): The patch size in patch embeddings. + Defaults to [7, 3, 3, 3]. 
+ in_channels (int): The num of input channels. Defaults to 3. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_indices (Sequence[int]): Output from which stages. + Default: ``(3, )``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmpretrain.models import VAN + >>> import torch + >>> cfg = dict(arch='tiny') + >>> model = VAN(**cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> outputs = model(inputs) + >>> for out in outputs: + >>> print(out.size()) + (1, 256, 7, 7) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': [32, 64, 160, 256], + 'depths': [3, 3, 5, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [2, 2, 4, 2], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 3, 12, 3], + 'ffn_ratios': [8, 8, 4, 4]}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': [64, 128, 320, 512], + 'depths': [3, 5, 27, 3], + 'ffn_ratios': [8, 8, 4, 4]}), + } # yapf: disable + + def __init__(self, + arch='tiny', + patch_sizes=[7, 3, 3, 3], + in_channels=3, + drop_rate=0., + drop_path_rate=0., + out_indices=(3, ), + frozen_stages=-1, + norm_eval=False, + norm_cfg=dict(type='LN'), + block_cfgs=dict(), + init_cfg=None): + super(VAN, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = {'embed_dims', 'depths', 'ffn_ratios'} + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.ffn_ratios = self.arch_settings['ffn_ratios'] + self.num_stages = len(self.depths) + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + cur_block_idx = 0 + for i, depth in enumerate(self.depths): + patch_embed = VANPatchEmbed( + in_channels=in_channels if i == 0 else self.embed_dims[i - 1], + input_size=None, + embed_dims=self.embed_dims[i], + kernel_size=patch_sizes[i], + stride=patch_sizes[i] // 2 + 1, + padding=(patch_sizes[i] // 2, patch_sizes[i] // 2), + norm_cfg=dict(type='BN')) + + blocks = ModuleList([ + VANBlock( + embed_dims=self.embed_dims[i], + ffn_ratio=self.ffn_ratios[i], + drop_rate=drop_rate, + drop_path_rate=dpr[cur_block_idx + j], + **block_cfgs) for j in range(depth) + ]) + cur_block_idx += depth + norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1] + + self.add_module(f'patch_embed{i + 1}', patch_embed) + 
self.add_module(f'blocks{i + 1}', blocks) + self.add_module(f'norm{i + 1}', norm) + + def train(self, mode=True): + super(VAN, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + # freeze patch embed + m = getattr(self, f'patch_embed{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze blocks + m = getattr(self, f'blocks{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze norm + m = getattr(self, f'norm{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + blocks = getattr(self, f'blocks{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, hw_shape = patch_embed(x) + for block in blocks: + x = block(x) + x = x.flatten(2).transpose(1, 2) + x = norm(x) + x = x.reshape(-1, *hw_shape, + block.out_channels).permute(0, 3, 1, 2).contiguous() + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/vgg.py b/mmpretrain/models/backbones/vgg.py new file mode 100644 index 0000000..026b916 --- /dev/null +++ b/mmpretrain/models/backbones/vgg.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +def make_vgg_layer(in_channels, + out_channels, + num_blocks, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + dilation=1, + with_norm=False, + ceil_mode=False): + layers = [] + for _ in range(num_blocks): + layer = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=dilation, + padding=dilation, + bias=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + layers.append(layer) + in_channels = out_channels + layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) + + return layers + + +@MODELS.register_module() +class VGG(BaseBackbone): + """VGG backbone. + + Args: + depth (int): Depth of vgg, from {11, 13, 16, 19}. + with_norm (bool): Use BatchNorm or not. + num_classes (int): number of classes for classification. + num_stages (int): VGG stages, normally 5. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int], optional): Output from which stages. + When it is None, the default behavior depends on whether + num_classes is specified. If num_classes <= 0, the default value is + (4, ), output the last feature map before classifier. If + num_classes > 0, the default value is (5, ), output the + classification score. Default: None. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False. + with_last_pool (bool): Whether to keep the last pooling before + classifier. Default: True. + """ + + # Parameters to build layers. Each element specifies the number of conv in + # each stage. 
For example, VGG11 contains 11 layers with learnable + # parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3, + # where 3 indicates the last three fully-connected layers. + arch_settings = { + 11: (1, 1, 2, 2, 2), + 13: (2, 2, 2, 2, 2), + 16: (2, 2, 3, 3, 3), + 19: (2, 2, 4, 4, 4) + } + + def __init__(self, + depth, + num_classes=-1, + num_stages=5, + dilations=(1, 1, 1, 1, 1), + out_indices=None, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + norm_eval=False, + ceil_mode=False, + with_last_pool=True, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict(type='Constant', val=1., layer=['_BatchNorm']), + dict(type='Normal', std=0.01, layer=['Linear']) + ]): + super(VGG, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for vgg') + assert num_stages >= 1 and num_stages <= 5 + stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + assert len(dilations) == num_stages + + self.num_classes = num_classes + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + with_norm = norm_cfg is not None + + if out_indices is None: + out_indices = (5, ) if num_classes > 0 else (4, ) + assert max(out_indices) <= num_stages + self.out_indices = out_indices + + self.in_channels = 3 + start_idx = 0 + vgg_layers = [] + self.range_sub_modules = [] + for i, num_blocks in enumerate(self.stage_blocks): + num_modules = num_blocks + 1 + end_idx = start_idx + num_modules + dilation = dilations[i] + out_channels = 64 * 2**i if i < 4 else 512 + vgg_layer = make_vgg_layer( + self.in_channels, + out_channels, + num_blocks, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dilation=dilation, + with_norm=with_norm, + ceil_mode=ceil_mode) + vgg_layers.extend(vgg_layer) + self.in_channels = out_channels + self.range_sub_modules.append([start_idx, end_idx]) + start_idx = end_idx + if not with_last_pool: + vgg_layers.pop(-1) + self.range_sub_modules[-1][1] -= 1 + self.module_name = 'features' + self.add_module(self.module_name, nn.Sequential(*vgg_layers)) + + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + outs = [] + vgg_layers = getattr(self, self.module_name) + for i in range(len(self.stage_blocks)): + for j in range(*self.range_sub_modules[i]): + vgg_layer = vgg_layers[j] + x = vgg_layer(x) + if i in self.out_indices: + outs.append(x) + if self.num_classes > 0: + x = x.view(x.size(0), -1) + x = self.classifier(x) + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + vgg_layers = getattr(self, self.module_name) + for i in range(self.frozen_stages): + for j in range(*self.range_sub_modules[i]): + m = vgg_layers[j] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(VGG, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/vig.py b/mmpretrain/models/backbones/vig.py new file mode 100644 index 0000000..c1a7879 --- /dev/null +++ b/mmpretrain/models/backbones/vig.py @@ -0,0 +1,852 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# modified from +# https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import ModuleList, Sequential +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +def get_2d_relative_pos_embed(embed_dim, grid_size): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, grid_size*grid_size] + """ + pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size) + relative_pos = 2 * np.matmul(pos_embed, + pos_embed.transpose()) / pos_embed.shape[1] + return relative_pos + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def xy_pairwise_distance(x, y): + """Compute pairwise distance of a point cloud. + + Args: + x: tensor (batch_size, num_points, num_dims) + y: tensor (batch_size, num_points, num_dims) + Returns: + pairwise distance: (batch_size, num_points, num_points) + """ + with torch.no_grad(): + xy_inner = -2 * torch.matmul(x, y.transpose(2, 1)) + x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) + y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True) + return x_square + xy_inner + y_square.transpose(2, 1) + + +def xy_dense_knn_matrix(x, y, k=16, relative_pos=None): + """Get KNN based on the pairwise distance. 
+ + Args: + x: (batch_size, num_dims, num_points, 1) + y: (batch_size, num_dims, num_points, 1) + k: int + relative_pos:Whether to use relative_pos + Returns: + nearest neighbors: + (batch_size, num_points, k) (batch_size, num_points, k) + """ + with torch.no_grad(): + x = x.transpose(2, 1).squeeze(-1) + y = y.transpose(2, 1).squeeze(-1) + batch_size, n_points, n_dims = x.shape + dist = xy_pairwise_distance(x.detach(), y.detach()) + if relative_pos is not None: + dist += relative_pos + _, nn_idx = torch.topk(-dist, k=k) + center_idx = torch.arange( + 0, n_points, device=x.device).repeat(batch_size, k, + 1).transpose(2, 1) + return torch.stack((nn_idx, center_idx), dim=0) + + +class DenseDilated(nn.Module): + """Find dilated neighbor from neighbor list. + + edge_index: (2, batch_size, num_points, k) + """ + + def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0): + super(DenseDilated, self).__init__() + self.dilation = dilation + self.use_stochastic = use_stochastic + self.epsilon = epsilon + self.k = k + + def forward(self, edge_index): + if self.use_stochastic: + if torch.rand(1) < self.epsilon and self.training: + num = self.k * self.dilation + randnum = torch.randperm(num)[:self.k] + edge_index = edge_index[:, :, :, randnum] + else: + edge_index = edge_index[:, :, :, ::self.dilation] + else: + edge_index = edge_index[:, :, :, ::self.dilation] + return edge_index + + +class DenseDilatedKnnGraph(nn.Module): + """Find the neighbors' indices based on dilated knn.""" + + def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0): + super(DenseDilatedKnnGraph, self).__init__() + self.dilation = dilation + self.use_stochastic = use_stochastic + self.epsilon = epsilon + self.k = k + self._dilated = DenseDilated(k, dilation, use_stochastic, epsilon) + + def forward(self, x, y=None, relative_pos=None): + if y is not None: + x = F.normalize(x, p=2.0, dim=1) + y = F.normalize(y, p=2.0, dim=1) + + edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, + relative_pos) + else: + x = F.normalize(x, p=2.0, dim=1) + y = x.clone() + + edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, + relative_pos) + return self._dilated(edge_index) + + +class BasicConv(Sequential): + + def __init__(self, + channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True, + drop=0.): + m = [] + for i in range(1, len(channels)): + m.append( + nn.Conv2d( + channels[i - 1], + channels[i], + 1, + bias=graph_conv_bias, + groups=4)) + if norm_cfg is not None: + m.append(build_norm_layer(norm_cfg, channels[-1])) + if act_cfg is not None: + m.append(build_activation_layer(act_cfg)) + if drop > 0: + m.append(nn.Dropout2d(drop)) + + super(BasicConv, self).__init__(*m) + + +def batched_index_select(x, idx): + r"""fetches neighbors features from a given neighbor idx + + Args: + x (Tensor): input feature Tensor + :math: + `\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. + idx (Tensor): edge_idx + :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. + Returns: + Tensor: output neighbors features + :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. 
+ """ + batch_size, num_dims, num_vertices_reduced = x.shape[:3] + _, num_vertices, k = idx.shape + idx_base = torch.arange( + 0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced + idx = idx + idx_base + idx = idx.contiguous().view(-1) + + x = x.transpose(2, 1) + feature = x.contiguous().view(batch_size * num_vertices_reduced, + -1)[idx, :] + feature = feature.view(batch_size, num_vertices, k, + num_dims).permute(0, 3, 1, 2).contiguous() + return feature + + +class MRConv2d(nn.Module): + """Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) + for dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(MRConv2d, self).__init__() + self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + + def forward(self, x, edge_index, y=None): + x_i = batched_index_select(x, edge_index[1]) + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j, _ = torch.max(x_j - x_i, -1, keepdim=True) + b, c, n, _ = x.shape + x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], + dim=2).reshape(b, 2 * c, n, _) + return self.nn(x) + + +class EdgeConv2d(nn.Module): + """Edge convolution layer (with activation, batch normalization) for dense + data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(EdgeConv2d, self).__init__() + self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + + def forward(self, x, edge_index, y=None): + x_i = batched_index_select(x, edge_index[1]) + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + max_value, _ = torch.max( + self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True) + return max_value + + +class GraphSAGE(nn.Module): + """GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216) + for dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(GraphSAGE, self).__init__() + self.nn1 = BasicConv([in_channels, in_channels], act_cfg, norm_cfg, + graph_conv_bias) + self.nn2 = BasicConv([in_channels * 2, out_channels], act_cfg, + norm_cfg, graph_conv_bias) + + def forward(self, x, edge_index, y=None): + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j, _ = torch.max(self.nn1(x_j), -1, keepdim=True) + return self.nn2(torch.cat([x, x_j], dim=1)) + + +class GINConv2d(nn.Module): + """GIN Graph Convolution (Paper: https://arxiv.org/abs/1810.00826) for + dense data type.""" + + def __init__(self, + in_channels, + out_channels, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + super(GINConv2d, self).__init__() + self.nn = BasicConv([in_channels, out_channels], act_cfg, norm_cfg, + graph_conv_bias) + eps_init = 0.0 + self.eps = nn.Parameter(torch.Tensor([eps_init])) + + def forward(self, x, edge_index, y=None): + if y is not None: + x_j = batched_index_select(y, edge_index[0]) + else: + x_j = batched_index_select(x, edge_index[0]) + x_j = torch.sum(x_j, -1, keepdim=True) + return self.nn((1 + self.eps) * x + x_j) + + +class GraphConv2d(nn.Module): + """Static graph convolution layer.""" + + def __init__(self, + in_channels, + out_channels, + graph_conv_type, + act_cfg, + norm_cfg=None, + graph_conv_bias=True): + 
super(GraphConv2d, self).__init__() + if graph_conv_type == 'edge': + self.gconv = EdgeConv2d(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + elif graph_conv_type == 'mr': + self.gconv = MRConv2d(in_channels, out_channels, act_cfg, norm_cfg, + graph_conv_bias) + elif graph_conv_type == 'sage': + self.gconv = GraphSAGE(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + elif graph_conv_type == 'gin': + self.gconv = GINConv2d(in_channels, out_channels, act_cfg, + norm_cfg, graph_conv_bias) + else: + raise NotImplementedError( + 'graph_conv_type:{} is not supported'.format(graph_conv_type)) + + def forward(self, x, edge_index, y=None): + return self.gconv(x, edge_index, y) + + +class DyGraphConv2d(GraphConv2d): + """Dynamic graph convolution layer.""" + + def __init__(self, + in_channels, + out_channels, + k=9, + dilation=1, + graph_conv_type='mr', + act_cfg=dict(type='GELU'), + norm_cfg=None, + graph_conv_bias=True, + use_stochastic=False, + epsilon=0.2, + r=1): + super(DyGraphConv2d, + self).__init__(in_channels, out_channels, graph_conv_type, + act_cfg, norm_cfg, graph_conv_bias) + self.k = k + self.d = dilation + self.r = r + self.dilated_knn_graph = DenseDilatedKnnGraph(k, dilation, + use_stochastic, epsilon) + + def forward(self, x, relative_pos=None): + B, C, H, W = x.shape + y = None + if self.r > 1: + y = F.avg_pool2d(x, self.r, self.r) + y = y.reshape(B, C, -1, 1).contiguous() + x = x.reshape(B, C, -1, 1).contiguous() + edge_index = self.dilated_knn_graph(x, y, relative_pos) + x = super(DyGraphConv2d, self).forward(x, edge_index, y) + return x.reshape(B, -1, H, W).contiguous() + + +class Grapher(nn.Module): + """Grapher module with graph convolution and fc layers.""" + + def __init__(self, + in_channels, + k=9, + dilation=1, + graph_conv_type='mr', + act_cfg=dict(type='GELU'), + norm_cfg=None, + graph_conv_bias=True, + use_stochastic=False, + epsilon=0.2, + r=1, + n=196, + drop_path=0.0, + relative_pos=False): + super(Grapher, self).__init__() + self.channels = in_channels + self.n = n + self.r = r + self.fc1 = Sequential( + nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), in_channels), + ) + self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, k, + dilation, graph_conv_type, act_cfg, + norm_cfg, graph_conv_bias, + use_stochastic, epsilon, r) + self.fc2 = Sequential( + nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), in_channels), + ) + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + self.relative_pos = None + if relative_pos: + relative_pos_tensor = torch.from_numpy( + np.float32( + get_2d_relative_pos_embed(in_channels, int( + n**0.5)))).unsqueeze(0).unsqueeze(1) + relative_pos_tensor = F.interpolate( + relative_pos_tensor, + size=(n, n // (r * r)), + mode='bicubic', + align_corners=False) + self.relative_pos = nn.Parameter( + -relative_pos_tensor.squeeze(1), requires_grad=False) + + def _get_relative_pos(self, relative_pos, H, W): + if relative_pos is None or H * W == self.n: + return relative_pos + else: + N = H * W + N_reduced = N // (self.r * self.r) + return F.interpolate( + relative_pos.unsqueeze(0), size=(N, N_reduced), + mode='bicubic').squeeze(0) + + def forward(self, x): + B, C, H, W = x.shape + relative_pos = self._get_relative_pos(self.relative_pos, H, W) + shortcut = x + x = self.fc1(x) + x = self.graph_conv(x, relative_pos) + x = self.fc2(x) + x = self.drop_path(x) + shortcut + return x + + +class FFN(nn.Module): + """"out_features = out_features or in_features\n + hidden_features = hidden_features or in_features""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop_path=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = Sequential( + nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), hidden_features), + ) + self.act = build_activation_layer(act_cfg) + self.fc2 = Sequential( + nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0), + build_norm_layer(dict(type='BN'), out_features), + ) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop_path(x) + shortcut + return x + + +@MODELS.register_module() +class Vig(BaseBackbone): + """Vision GNN backbone. + + A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes + `_. + + Modified from the official implementation + https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch + + Args: + arch(str): Vision GNN architecture, + choose from 'tiny', 'small' and 'base'. + in_channels (int): The number of channels of input images. + Defaults to 3. + k (int): The number of KNN's k. Defaults to 9. + out_indices (Sequence | int): Output from which blocks. + Defaults to -1, means the last block. + act_cfg (dict): The config of activative functions. + Defaults to ``dict(type='GELU'))``. + norm_cfg (dict): The config of normalization layers. + Defaults to ``dict(type='BN', eps=1e-6)``. + graph_conv_bias (bool): Whether to use bias in the convolution + layers in Grapher. Defaults to True. + graph_conv_type (str): The type of graph convolution,choose + from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'. + epsilon (float): Probability of random arrangement in KNN. It only + works when ``use_dilation=True`` and ``use_stochastic=True``. + Defaults to 0.2. + use_dilation(bool): Whether to use dilation in KNN. Defaults to True. + use_stochastic(bool): Whether to use stochastic in KNN. + Defaults to False. + drop_path (float): stochastic depth rate. Default 0.0 + relative_pos(bool): Whether to use relative position embedding. + Defaults to False. + norm_eval (bool): Whether to set the normalization layer to eval mode. + Defaults to False. + frozen_stages (int): Blocks to be frozen (all param fixed). 
+ Defaults to 0, which means not freezing any parameters. + init_cfg (dict, optional): The initialization configs. + Defaults to None. + """ # noqa: E501 + + arch_settings = { + 'tiny': dict(num_blocks=12, channels=192), + 'small': dict(num_blocks=16, channels=320), + 'base': dict(num_blocks=16, channels=640), + } + + def __init__(self, + arch, + in_channels=3, + k=9, + out_indices=-1, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN'), + graph_conv_bias=True, + graph_conv_type='mr', + epsilon=0.2, + use_dilation=True, + use_stochastic=False, + drop_path=0., + relative_pos=False, + norm_eval=False, + frozen_stages=0, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + arch = self.arch_settings[arch] + self.num_blocks = arch['num_blocks'] + channels = arch['channels'] + + if isinstance(out_indices, int): + out_indices = [out_indices] + elif isinstance(out_indices, tuple): + out_indices = list(out_indices) + elif not isinstance(out_indices, list): + raise TypeError('"out_indices" must by a tuple, list or int, ' + f'get {type(out_indices)} instead.') + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_blocks + index + assert 0 <= out_indices[i] <= self.num_blocks, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.stem = Sequential( + nn.Conv2d(in_channels, channels // 8, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 8), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 8, channels // 4, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 4), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 4, channels // 2, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels // 2), + build_activation_layer(act_cfg), + nn.Conv2d(channels // 2, channels, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels), + build_activation_layer(act_cfg), + nn.Conv2d(channels, channels, 3, stride=1, padding=1), + build_norm_layer(norm_cfg, channels), + ) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)] + # number of knn's k + num_knn = [ + int(x.item()) for x in torch.linspace(k, 2 * k, self.num_blocks) + ] + max_dilation = 196 // max(num_knn) + + self.pos_embed = nn.Parameter(torch.zeros(1, channels, 14, 14)) + + self.blocks = ModuleList([ + Sequential( + Grapher( + in_channels=channels, + k=num_knn[i], + dilation=min(i // 4 + + 1, max_dilation) if use_dilation else 1, + graph_conv_type=graph_conv_type, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + graph_conv_bias=graph_conv_bias, + use_stochastic=use_stochastic, + epsilon=epsilon, + drop_path=dpr[i], + relative_pos=relative_pos), + FFN(in_features=channels, + hidden_features=channels * 4, + act_cfg=act_cfg, + drop_path=dpr[i])) for i in range(self.num_blocks) + ]) + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + def forward(self, inputs): + outs = [] + x = self.stem(inputs) + self.pos_embed + + for i, block in enumerate(self.blocks): + x = block(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + self.stem.eval() + for i in range(self.frozen_stages): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(Vig, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() 
+class PyramidVig(BaseBackbone):
+    """Pyramid Vision GNN backbone.
+
+    A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes
+    <https://arxiv.org/abs/2206.00272>`_.
+
+    Modified from the official implementation
+    https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch
+
+    Args:
+        arch (str): Vision GNN architecture, choose from 'tiny', 'small',
+            'medium' and 'base'.
+        in_channels (int): The number of channels of input images.
+            Defaults to 3.
+        k (int): The number of KNN's k. Defaults to 9.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, which means the last stage.
+        act_cfg (dict): The config of activation functions.
+            Defaults to ``dict(type='GELU')``.
+        norm_cfg (dict): The config of normalization layers.
+            Defaults to ``dict(type='BN')``.
+        graph_conv_bias (bool): Whether to use bias in the convolution
+            layers in Grapher. Defaults to True.
+        graph_conv_type (str): The type of graph convolution, choose
+            from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'.
+        epsilon (float): Probability of random arrangement in KNN. It only
+            works when ``use_stochastic=True``. Defaults to 0.2.
+        use_stochastic (bool): Whether to use stochastic sampling in KNN.
+            Defaults to False.
+        drop_path (float): Stochastic depth rate. Defaults to 0.
+        norm_eval (bool): Whether to set the normalization layer to eval mode.
+            Defaults to False.
+        frozen_stages (int): Stages to be frozen (all parameters fixed).
+            Defaults to 0, which means not freezing any parameters.
+        init_cfg (dict, optional): The initialization configs.
+            Defaults to None.
+    """  # noqa: E501
+    arch_settings = {
+        'tiny': dict(blocks=[2, 2, 6, 2], channels=[48, 96, 240, 384]),
+        'small': dict(blocks=[2, 2, 6, 2], channels=[80, 160, 400, 640]),
+        'medium': dict(blocks=[2, 2, 16, 2], channels=[96, 192, 384, 768]),
+        'base': dict(blocks=[2, 2, 18, 2], channels=[128, 256, 512, 1024]),
+    }
+
+    def __init__(self,
+                 arch,
+                 in_channels=3,
+                 k=9,
+                 out_indices=-1,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='BN'),
+                 graph_conv_bias=True,
+                 graph_conv_type='mr',
+                 epsilon=0.2,
+                 use_stochastic=False,
+                 drop_path=0.,
+                 norm_eval=False,
+                 frozen_stages=0,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        arch = self.arch_settings[arch]
+        self.blocks = arch['blocks']
+        self.num_blocks = sum(self.blocks)
+        self.num_stages = len(self.blocks)
+        channels = arch['channels']
+        self.channels = channels
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_stages + index + assert 0 <= out_indices[i] <= self.num_stages, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.stem = Sequential( + nn.Conv2d(in_channels, channels[0] // 2, 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels[0] // 2), + build_activation_layer(act_cfg), + nn.Conv2d(channels[0] // 2, channels[0], 3, stride=2, padding=1), + build_norm_layer(norm_cfg, channels[0]), + build_activation_layer(act_cfg), + nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1), + build_norm_layer(norm_cfg, channels[0]), + ) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)] + # number of knn's k + num_knn = [ + int(x.item()) for x in torch.linspace(k, k, self.num_blocks) + ] + max_dilation = 49 // max(num_knn) + + self.pos_embed = nn.Parameter( + torch.zeros(1, channels[0], 224 // 4, 224 // 4)) + HW = 224 // 4 * 224 // 4 + reduce_ratios = [4, 2, 1, 1] + + self.stages = ModuleList() + block_idx = 0 + for stage_idx, num_blocks in enumerate(self.blocks): + mid_channels = channels[stage_idx] + reduce_ratio = reduce_ratios[stage_idx] + blocks = [] + if stage_idx > 0: + blocks.append( + Sequential( + nn.Conv2d( + self.channels[stage_idx - 1], + mid_channels, + kernel_size=3, + stride=2, + padding=1), + build_norm_layer(norm_cfg, mid_channels), + )) + HW = HW // 4 + for _ in range(num_blocks): + blocks.append( + Sequential( + Grapher( + in_channels=mid_channels, + k=num_knn[block_idx], + dilation=min(block_idx // 4 + 1, max_dilation), + graph_conv_type=graph_conv_type, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + graph_conv_bias=graph_conv_bias, + use_stochastic=use_stochastic, + epsilon=epsilon, + r=reduce_ratio, + n=HW, + drop_path=dpr[block_idx], + relative_pos=True), + FFN(in_features=mid_channels, + hidden_features=mid_channels * 4, + act_cfg=act_cfg, + drop_path=dpr[block_idx]))) + block_idx += 1 + self.stages.append(Sequential(*blocks)) + + self.norm_eval = norm_eval + self.frozen_stages = frozen_stages + + def forward(self, inputs): + outs = [] + x = self.stem(inputs) + self.pos_embed + + for i, blocks in enumerate(self.stages): + x = blocks(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + self.stem.eval() + for i in range(self.frozen_stages): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(PyramidVig, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py new file mode 100644 index 0000000..a54053c --- /dev/null +++ b/mmpretrain/models/backbones/vision_transformer.py @@ -0,0 +1,537 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
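A minimal usage sketch for the `Vig`/`PyramidVig` backbones defined above (editorial example, not part of the patch): it assumes mmpretrain and its dependencies are installed and that the registry can locate these modules. The shapes follow from the code above: the stem downsamples by 4 and each later stage halves the resolution, so a 224x224 input should end at 7x7 with 384 channels for the 'tiny' pyramid variant.

```python
import torch
from mmpretrain.registry import MODELS

# 'PyramidVig' is the name registered by the @MODELS.register_module()
# decorator above; the registry builds the backbone from a plain config dict.
backbone = MODELS.build(dict(type='PyramidVig', arch='tiny'))
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))

# With the default out_indices=-1, only the last stage is returned,
# e.g. a (1, 384, 7, 7) feature map for the 'tiny' architecture.
print([f.shape for f in feats])
```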
+from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer, + resize_pos_embed, to_2tuple) +from .base_backbone import BaseBackbone + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 0. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + ffn_type (str): Select the type of ffn layers. Defaults to 'origin'. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + layer_scale_init_value=0., + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + ffn_type='origin', + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg=None): + super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + layer_scale_init_value=layer_scale_init_value) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + if ffn_type == 'origin': + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + layer_scale_init_value=layer_scale_init_value) + elif ffn_type == 'swiglu_fused': + self.ffn = SwiGLUFFNFused( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + layer_scale_init_value=layer_scale_init_value) + else: + raise NotImplementedError + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def init_weights(self): + super(TransformerEncoderLayer, self).init_weights() + for m in self.ffn.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + return x + + +@MODELS.register_module() +class VisionTransformer(BaseBackbone): + """Vision Transformer. + + A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_ + + Args: + arch (str | dict): Vision Transformer architecture. 
If using a string,
+            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
+            and 'deit-base'. If using a dict, it should contain the keys
+            below:
+
+            - **embed_dims** (int): The dimensions of embedding.
+            - **num_layers** (int): The number of transformer encoder layers.
+            - **num_heads** (int): The number of heads in attention modules.
+            - **feedforward_channels** (int): The hidden dimensions in
+              feedforward modules.
+
+            Defaults to 'base'.
+        img_size (int | tuple): The expected input image shape. Because we
+            support dynamic input shape, just set the argument to the most
+            common input image shape. Defaults to 224.
+        patch_size (int | tuple): The patch size in patch embedding.
+            Defaults to 16.
+        in_channels (int): The number of input channels. Defaults to 3.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, which means the last stage.
+        drop_rate (float): Probability of an element to be zeroed.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        qkv_bias (bool): Whether to add bias for qkv in attention modules.
+            Defaults to True.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize
+            the final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor including patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"cls_token"``.
+        with_cls_token (bool): Whether to concatenate the class token into
+            the image tokens as transformer input. Defaults to True.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resize. Defaults to "bicubic".
+        layer_scale_init_value (float or torch.Tensor): Init value of layer
+            scale. Defaults to 0.
+        patch_cfg (dict): Configs of patch embedding.
+            Defaults to an empty dict.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            encoder. Defaults to an empty dict.
+        pre_norm (bool): Whether to add a norm layer before the transformer
+            encoder layers (used by CLIP-style ViTs, see the patch embedding
+            bias handling below). Defaults to False.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+ """ + arch_zoo = { + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 768, + 'num_layers': 8, + 'num_heads': 8, + 'feedforward_channels': 768 * 3, + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072 + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096 + }), + **dict.fromkeys( + ['h', 'huge'], + { + # The same as the implementation in MAE + # + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120 + }), + **dict.fromkeys( + ['eva-g', 'eva-giant'], + { + # The implementation in EVA + # + 'embed_dims': 1408, + 'num_layers': 40, + 'num_heads': 16, + 'feedforward_channels': 6144 + }), + **dict.fromkeys( + ['deit-t', 'deit-tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': 192 * 4 + }), + **dict.fromkeys( + ['deit-s', 'deit-small', 'dinov2-s', 'dinov2-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': 384 * 4 + }), + **dict.fromkeys( + ['deit-b', 'deit-base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 768 * 4 + }), + **dict.fromkeys( + ['dinov2-g', 'dinov2-giant'], { + 'embed_dims': 1536, + 'num_layers': 40, + 'num_heads': 24, + 'feedforward_channels': 6144 + }), + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + out_type='cls_token', + with_cls_token=True, + frozen_stages=-1, + interpolate_mode='bicubic', + layer_scale_init_value=0., + patch_cfg=dict(), + layer_cfgs=dict(), + pre_norm=False, + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm, # disable bias if pre_norm is used(e.g., CLIP) + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + # Set cls token + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when 
`out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + layer_scale_init_value=layer_scale_init_value, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + + self.frozen_stages = frozen_stages + if pre_norm: + self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims) + else: + self.pre_norm = nn.Identity() + + self.final_norm = final_norm + if final_norm: + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + if self.out_type == 'avg_featmap': + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def init_weights(self): + super(VisionTransformer, self).init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if (not self.with_cls_token + and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1): + # Remove cls token from state dict if it's not used. 
+ state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze pre-norm + for param in self.pre_norm.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + if self.out_type == 'avg_featmap': + self.ln2.eval() + for param in self.ln2.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(patch_token.mean(dim=1)) + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. 
+ + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = self.num_layers + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name in ('cls_token', 'pos_embed'): + layer_depth = 0 + elif param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + layer_depth = layer_id + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/vit_eva02.py b/mmpretrain/models/backbones/vit_eva02.py new file mode 100644 index 0000000..20ec4b2 --- /dev/null +++ b/mmpretrain/models/backbones/vit_eva02.py @@ -0,0 +1,350 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer, + resize_pos_embed) +from .vision_transformer import VisionTransformer + + +class AttentionWithRoPE(BaseModule): + """Multi-head Attention Module with 2D sincos position embedding (RoPE). + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + qkv_bias (bool): If True, add a learnable bias to q and v. Note + that we follows the official implementation where ``k_bias`` + is 0. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + rope (:obj:`torch.nn.Module`, optional): If it is an object of the + ``RotaryEmbedding``, the rotation of the token position will be + performed before the softmax. Defaults to None. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + qkv_bias=True, + qk_scale=None, + proj_bias=True, + rope=None, + with_cls_token=True, + init_cfg=None): + super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.head_dims = embed_dims // num_heads + self.scale = qk_scale or self.head_dims**-0.5 + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.with_cls_token = with_cls_token + + self.rope = rope + + def forward(self, x, patch_resolution): + B, N, _ = x.shape + + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(dim=0) + + if self.rope: + if self.with_cls_token: + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t, patch_resolution) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] if self.with_cls_token else k + ro_k_t = self.rope(k_t, patch_resolution) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + else: + q = self.rope(q, patch_resolution) + k = self.rope(k, patch_resolution) + + q = q * self.scale + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1).type_as(x) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class EVA02EndcoderLayer(BaseModule): + """Implements one encoder EVA02EndcoderLayer in EVA02. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension of FFNs. + sub_ln (bool): Whether to add the sub layer normalization + in the attention module. Defaults to False. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool): enable bias for projection in the attention module + if True. Defaults to True. + rope (:obj:`torch.nn.Module`, optional): RotaryEmbedding object + in the attention module. Defaults to None. + drop_rate (float): Dropout rate in the mlp module. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + sub_ln=False, + attn_drop=0., + proj_drop=0., + qkv_bias=False, + qk_scale=None, + proj_bias=True, + rope=None, + with_cls_token=True, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + init_cfg=None): + super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims) + + self.attn = AttentionWithRoPE( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + proj_bias=proj_bias, + rope=rope, + with_cls_token=with_cls_token) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate)) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims) + + if drop_rate > 0: + dropout_layer = dict(type='Dropout', drop_prob=drop_rate) + else: + dropout_layer = None + + if sub_ln: + ffn_norm = norm_cfg + else: + ffn_norm = None + + self.mlp = SwiGLUFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + dropout_layer=dropout_layer, + norm_cfg=ffn_norm, + add_identity=False, + ) + + def forward(self, x, patch_resolution): + inputs = x + x = self.norm1(x) + x = self.attn(x, patch_resolution) + x = self.drop_path(x) + x = inputs + x + + inputs = x + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = inputs + x + + return x + + +@MODELS.register_module() +class ViTEVA02(VisionTransformer): + """EVA02 Vision Transformer. + + A PyTorch implement of : `EVA-02: A Visual Representation for Neon Genesis + `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'tiny', 'small', 'base', 'large'. If use dict, + it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **mlp_ratio** (float): The ratio of the mlp module. + + Defaults to 'tiny'. + + sub_ln (bool): Whether to add the sub layer normalization in swiglu. + Defaults to False. + drop_rate (float): Probability of an element to be zeroed in the + mlp module. Defaults to 0. + attn_drop_rate (float): Probability of an element to be zeroed after + the softmax in the attention. Defaults to 0. + proj_drop_rate (float): Probability of an element to be zeroed after + projection in the attention. Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + **kwargs(dict, optional): Other args for Vision Transformer. 
+ """ + arch_zoo = { + **dict.fromkeys( + ['t', 'ti', 'tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': int(192 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': int(384 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': int(768 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': int(1024 * 4 * 2 / 3) + }) + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='tiny', + sub_ln=False, + drop_rate=0., + attn_drop_rate=0., + proj_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN'), + with_cls_token=True, + layer_cfgs=dict(), + **kwargs): + # set essential args for Vision Transformer + kwargs.update( + arch=arch, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + with_cls_token=with_cls_token) + super(ViTEVA02, self).__init__(**kwargs) + + self.num_heads = self.arch_settings['num_heads'] + + # Set RoPE + head_dim = self.embed_dims // self.num_heads + self.rope = RotaryEmbeddingFast( + embed_dims=head_dim, patch_resolution=self.patch_resolution) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self. + arch_settings['feedforward_channels'], + sub_ln=sub_ln, + norm_cfg=norm_cfg, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_rate=drop_rate, + qkv_bias=qkv_bias, + rope=self.rope, + with_cls_token=with_cls_token, + drop_path_rate=dpr[i]) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(EVA02EndcoderLayer(**_layer_cfg)) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, patch_resolution) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) diff --git a/mmpretrain/models/backbones/vit_sam.py b/mmpretrain/models/backbones/vit_sam.py new file mode 100644 index 0000000..0eb46a7 --- /dev/null +++ b/mmpretrain/models/backbones/vit_sam.py @@ -0,0 +1,697 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from ..utils import LayerNorm2d, build_norm_layer, resize_pos_embed, to_2tuple +from .base_backbone import BaseBackbone + + +def window_partition(x: torch.Tensor, + window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Partition into non-overlapping windows with padding if needed. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + x (torch.Tensor): Input tokens with [B, H, W, C]. + window_size (int): Window size. + + Returns: + Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + + - ``windows``: Windows after partition with + [B * num_windows, window_size, window_size, C]. + - ``(Hp, Wp)``: Padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int]) -> torch.Tensor: + """Window unpartition into original sequences and removing padding. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + x (torch.Tensor): Input tokens with + [B * num_windows, window_size, window_size, C]. + window_size (int): Window size. + pad_hw (tuple): Padded height and width (Hp, Wp). + hw (tuple): Original height and width (H, W) before padding. + + Returns: + torch.Tensor: Unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, + rel_pos: torch.Tensor) -> torch.Tensor: + """Get relative positional embeddings according to the relative positions + of query and key sizes. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + q_size (int): Size of query q. + k_size (int): Size of key k. + rel_pos (torch.Tensor): Relative position embeddings (L, C). + + Returns: + torch.Tensor: Extracted positional embeddings according to relative + positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode='linear', + ) + rel_pos_resized = rel_pos_resized.reshape(-1, + max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. 
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - + k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """Borrowed from https://github.com/facebookresearch/segment-anything/ + + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (torch.Tensor): Attention map. + q (torch.Tensor): Query q in the attention layer with shape + (B, q_h * q_w, C). + rel_pos_h (torch.Tensor): Relative position embeddings (Lh, C) for + height axis. + rel_pos_w (torch.Tensor): Relative position embeddings (Lw, C) for + width axis. + q_size (tuple): Spatial sequence size of query q with (q_h, q_w). + k_size (tuple): Spatial sequence size of key k with (k_h, k_w). + + Returns: + torch.Tensor: Attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) + rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings. + + Borrowed from https://github.com/facebookresearch/segment-anything/ + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to False. + input_size (int, optional): Input resolution for calculating the + relative positional parameter size. Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = head_embed_dims**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dims, embed_dims) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert (input_size is not None), \ + 'Input size must be provided if using relative position embed.' 
+ # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter( + torch.zeros(2 * input_size[0] - 1, head_embed_dims)) + self.rel_pos_w = nn.Parameter( + torch.zeros(2 * input_size[1] - 1, head_embed_dims)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, + -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, + self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, + -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +class TransformerEncoderLayer(BaseModule): + """Encoder layer with window attention in Vision Transformer. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension for FFNs + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to False. + window_size (int): Window size for window attention. Defaults to 0. + input_size (int, optional): Input resolution for calculating the + relative positional parameter size. Defaults to None. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + use_rel_pos: bool = False, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.window_size = window_size + + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) + + self.attn = Attention( + embed_dims=embed_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + input_size=input_size if window_size == 0 else + (window_size, window_size), + ) + + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x): + shortcut = x + x = self.ln1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + x = shortcut + x + + x = self.ffn(self.ln2(x), identity=x) + return x + + +@MODELS.register_module() +class ViTSAM(BaseBackbone): + """Vision Transformer as image encoder used in SAM. + + A PyTorch implement of backbone: `Segment Anything + `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'base', 'large', 'huge'. If use dict, it should have + below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + - **global_attn_indexes** (int): The index of layers with global + attention. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_channels (int): The num of output channels, if equal to 0, the + channel reduction layer is disabled. Defaults to 256. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + out_type (str): The type of output features. Please choose from + + - ``"raw"`` or ``"featmap"``: The feature map tensor from the + patch tokens with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + + Defaults to ``"raw"``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + use_abs_pos (bool): Whether to use absolute position embedding. + Defaults to True. + use_rel_pos (bool):Whether to use relative position embedding. + Defaults to True. + window_size (int): Window size for window attention. Defaults to 14. 
+ norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 3072, + 'global_attn_indexes': [2, 5, 8, 11] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': 4096, + 'global_attn_indexes': [5, 11, 17, 23] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 1280, + 'num_layers': 32, + 'num_heads': 16, + 'feedforward_channels': 5120, + 'global_attn_indexes': [7, 15, 23, 31] + }), + } + OUT_TYPES = {'raw', 'featmap', 'avg_featmap'} + + def __init__(self, + arch: str = 'base', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_channels: int = 256, + out_indices: int = -1, + out_type: str = 'raw', + drop_rate: float = 0., + drop_path_rate: float = 0., + qkv_bias: bool = True, + use_abs_pos: bool = True, + use_rel_pos: bool = True, + window_size: int = 14, + norm_cfg: dict = dict(type='LN', eps=1e-6), + frozen_stages: int = -1, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: Optional[dict] = None): + super().__init__(init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.num_layers = self.arch_settings['num_layers'] + self.global_attn_indexes = self.arch_settings['global_attn_indexes'] + self.img_size = to_2tuple(img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + + self.use_abs_pos = use_abs_pos + self.interpolate_mode = interpolate_mode + if use_abs_pos: + # Set position embedding + self.pos_embed = nn.Parameter( + torch.zeros(1, *self.patch_resolution, self.embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + if use_rel_pos: + self._register_load_state_dict_pre_hook( + self._prepare_relative_position) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' 
\ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}' + self.out_indices = out_indices + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + qkv_bias=qkv_bias, + window_size=window_size + if i not in self.global_attn_indexes else 0, + input_size=self.patch_resolution, + use_rel_pos=use_rel_pos, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + + self.out_channels = out_channels + if self.out_channels > 0: + self.channel_reduction = nn.Sequential( + nn.Conv2d( + self.embed_dims, + out_channels, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_channels, eps=1e-6), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_channels, eps=1e-6), + ) + + # freeze stages only when self.frozen_stages > 0 + self.frozen_stages = frozen_stages + if self.frozen_stages > 0: + self._freeze_stages() + + def init_weights(self): + super().init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze channel_reduction module + if self.frozen_stages == self.num_layers and self.out_channels > 0: + m = self.channel_reduction + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + x = x.view(B, patch_resolution[0], patch_resolution[1], + self.embed_dims) + + if self.use_abs_pos: + # 'resize_pos_embed' only supports 'pos_embed' with ndim==3, but + # in ViTSAM, the 'pos_embed' has 4 dimensions (1, H, W, C), so it + # is flattened. Besides, ViTSAM doesn't have any extra token. 
+ resized_pos_embed = resize_pos_embed( + self.pos_embed.flatten(1, 2), + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=0) + x = x + resized_pos_embed.view(1, *patch_resolution, + self.embed_dims) + x = self.drop_after_pos(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i in self.out_indices: + # (B, H, W, C) -> (B, C, H, W) + x_reshape = x.permute(0, 3, 1, 2) + + if self.out_channels > 0: + x_reshape = self.channel_reduction(x_reshape) + outs.append(self._format_output(x_reshape)) + + return tuple(outs) + + def _format_output(self, x) -> torch.Tensor: + if self.out_type == 'raw' or self.out_type == 'featmap': + return x + elif self.out_type == 'avg_featmap': + # (B, C, H, W) -> (B, C, N) -> (B, N, C) + x = x.flatten(2).permute(0, 2, 1) + return x.mean(dim=1) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = ckpt_pos_embed_shape[1:3] + pos_embed_shape = self.patch_embed.init_out_size + + flattened_pos_embed = state_dict[name].flatten(1, 2) + resized_pos_embed = resize_pos_embed(flattened_pos_embed, + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, 0) + state_dict[name] = resized_pos_embed.view(1, *pos_embed_shape, + self.embed_dims) + + def _prepare_relative_position(self, state_dict, prefix, *args, **kwargs): + state_dict_model = self.state_dict() + all_keys = list(state_dict_model.keys()) + for key in all_keys: + if 'rel_pos_' in key: + ckpt_key = prefix + key + if ckpt_key not in state_dict: + continue + relative_position_pretrained = state_dict[ckpt_key] + relative_position_current = state_dict_model[key] + L1, _ = relative_position_pretrained.size() + L2, _ = relative_position_current.size() + if L1 != L2: + new_rel_pos = F.interpolate( + relative_position_pretrained.reshape(1, L1, + -1).permute( + 0, 2, 1), + size=L2, + mode='linear', + ) + new_rel_pos = new_rel_pos.reshape(-1, L2).permute(1, 0) + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info(f'Resize the {ckpt_key} from ' + f'{state_dict[ckpt_key].shape} to ' + f'{new_rel_pos.shape}') + state_dict[ckpt_key] = new_rel_pos + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. 
+ + Note: + The first depth is the stem module (``layer_depth=0``), and the + last depth is the subsequent module (``layer_depth=num_layers-1``) + """ + num_layers = self.num_layers + 2 + + if not param_name.startswith(prefix): + # For subsequent module like head + return num_layers - 1, num_layers + + param_name = param_name[len(prefix):] + + if param_name in ('cls_token', 'pos_embed'): + layer_depth = 0 + elif param_name.startswith('patch_embed'): + layer_depth = 0 + elif param_name.startswith('layers'): + layer_id = int(param_name.split('.')[1]) + layer_depth = layer_id + 1 + else: + layer_depth = num_layers - 1 + + return layer_depth, num_layers diff --git a/mmpretrain/models/backbones/xcit.py b/mmpretrain/models/backbones/xcit.py new file mode 100644 index 0000000..392ebbe --- /dev/null +++ b/mmpretrain/models/backbones/xcit.py @@ -0,0 +1,770 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import partial +from typing import Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks import ConvModule, DropPath +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, Sequential +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import digit_version + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer, to_2tuple +from .base_backbone import BaseBackbone + +if digit_version(torch.__version__) < digit_version('1.8.0'): + floor_div = torch.floor_divide +else: + floor_div = partial(torch.div, rounding_mode='floor') + + +class ClassAttntion(BaseModule): + """Class Attention Module. + + A PyTorch implementation of Class Attention Module introduced by: + `Going deeper with Image Transformers `_ + + taken from + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + with slight modifications to do CA + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. Defaults to 8. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + proj_drop (float): The drop out rate for linear output weights. + Defaults to 0. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ # noqa: E501 + + def __init__(self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + init_cfg=None): + + super(ClassAttntion, self).__init__(init_cfg=init_cfg) + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # We only need to calculate query of cls token. 
+ q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, + C // self.num_heads).permute( + 0, 2, 1, 3) + k = self.k(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + q = q * self.scale + v = self.v(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) + x_cls = self.proj(x_cls) + x_cls = self.proj_drop(x_cls) + + return x_cls + + +class PositionalEncodingFourier(BaseModule): + """Positional Encoding using a fourier kernel. + + A PyTorch implementation of Positional Encoding relying on + a fourier kernel introduced by: + `Attention is all you Need `_ + + Based on the `official XCiT code + `_ + + Args: + hidden_dim (int): The hidden feature dimension. Defaults to 32. + dim (int): The output feature dimension. Defaults to 768. + temperature (int): A control variable for position encoding. + Defaults to 10000. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + hidden_dim: int = 32, + dim: int = 768, + temperature: int = 10000, + init_cfg=None): + super(PositionalEncodingFourier, self).__init__(init_cfg=init_cfg) + + self.token_projection = ConvModule( + in_channels=hidden_dim * 2, + out_channels=dim, + kernel_size=1, + conv_cfg=None, + norm_cfg=None, + act_cfg=None) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + self.eps = 1e-6 + + def forward(self, B: int, H: int, W: int): + device = self.token_projection.conv.weight.device + y_embed = torch.arange( + 1, H + 1, device=device).unsqueeze(1).repeat(1, 1, W).float() + x_embed = torch.arange(1, W + 1, device=device).repeat(1, H, 1).float() + y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, device=device).float() + dim_t = floor_div(dim_t, 2) + dim_t = self.temperature**(2 * dim_t / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + [pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], + dim=4).flatten(3) + pos_y = torch.stack( + [pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + return pos.repeat(B, 1, 1, 1) # (B, C, H, W) + + +class ConvPatchEmbed(BaseModule): + """Patch Embedding using multiple convolution layers. + + Args: + img_size (int, tuple): input image size. + Defaults to 224, means the size is 224*224. + patch_size (int): The patch size in conv patch embedding. + Defaults to 16. + in_channels (int): The input channels of this module. + Defaults to 3. + embed_dims (int): The feature dimension + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + img_size: Union[int, tuple] = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dims: int = 768, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(ConvPatchEmbed, self).__init__(init_cfg=init_cfg) + img_size = to_2tuple(img_size) + num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + conv = partial( + ConvModule, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + layer = [] + if patch_size == 16: + layer.append( + conv(in_channels=in_channels, out_channels=embed_dims // 8)) + layer.append( + conv( + in_channels=embed_dims // 8, out_channels=embed_dims // 4)) + elif patch_size == 8: + layer.append( + conv(in_channels=in_channels, out_channels=embed_dims // 4)) + else: + raise ValueError('For patch embedding, the patch size must be 16 ' + f'or 8, but get patch size {self.patch_size}.') + + layer.append( + conv(in_channels=embed_dims // 4, out_channels=embed_dims // 2)) + layer.append( + conv( + in_channels=embed_dims // 2, + out_channels=embed_dims, + act_cfg=None, + )) + + self.proj = Sequential(*layer) + + def forward(self, x: torch.Tensor): + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) # (B, N, C) + return x, (Hp, Wp) + + +class ClassAttentionBlock(BaseModule): + """Transformer block using Class Attention. + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. + mlp_ratio (float): The hidden dimension ratio for FFN. + Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initial value for layer scale. + Defaults to 1. + tokens_norm (bool): Whether to normalize all tokens or just the + cls_token in the CA. Defaults to False. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN', eps=1e-6)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop=0., + attn_drop=0., + drop_path=0., + layer_scale_init_value=1., + tokens_norm=False, + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=None): + + super(ClassAttentionBlock, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, dim) + + self.attn = ClassAttntion( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = build_norm_layer(norm_cfg, dim) + + self.ffn = FFN( + embed_dims=dim, + feedforward_channels=int(dim * mlp_ratio), + act_cfg=act_cfg, + ffn_drop=drop, + ) + + if layer_scale_init_value > 0: + self.gamma1 = nn.Parameter(layer_scale_init_value * + torch.ones(dim)) + self.gamma2 = nn.Parameter(layer_scale_init_value * + torch.ones(dim)) + else: + self.gamma1, self.gamma2 = 1.0, 1.0 + + # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501 + self.tokens_norm = tokens_norm + + def forward(self, x): + x_norm1 = self.norm1(x) + x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1) + x = x + self.drop_path(self.gamma1 * x_attn) + if self.tokens_norm: + x = self.norm2(x) + else: + x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1) + x_res = x + cls_token = x[:, 0:1] + cls_token = self.gamma2 * self.ffn(cls_token, identity=0) + x = torch.cat([cls_token, x[:, 1:]], dim=1) + x = x_res + self.drop_path(x) + return x + + +class LPI(BaseModule): + """Local Patch Interaction module. + + A PyTorch implementation of Local Patch Interaction module + as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers + `_ + + Local Patch Interaction module that allows explicit communication between + tokens in 3x3 windows to augment the implicit communication performed by + the block diagonal scatter attention. Implemented using 2 layers of + separable 3x3 convolutions with GeLU and BatchNorm2d + + Args: + in_features (int): The input channels. + out_features (int, optional): The output channels. Defaults to None. + kernel_size (int): The kernel_size in ConvModule. Defaults to 3. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_features: int, + out_features: Optional[int] = None, + kernel_size: int = 3, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(LPI, self).__init__(init_cfg=init_cfg) + + out_features = out_features or in_features + padding = kernel_size // 2 + + self.conv1 = ConvModule( + in_channels=in_features, + out_channels=in_features, + kernel_size=kernel_size, + padding=padding, + groups=in_features, + bias=True, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + order=('conv', 'act', 'norm')) + + self.conv2 = ConvModule( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=out_features, + norm_cfg=None, + act_cfg=None) + + def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: + B, N, C = x.shape + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.conv1(x) + x = self.conv2(x) + x = x.reshape(B, C, N).permute(0, 2, 1) + return x + + +class XCA(BaseModule): + r"""Cross-Covariance Attention module. + + A PyTorch implementation of Cross-Covariance Attention module + as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers + `_ + + In Cross-Covariance Attention (XCA), the channels are updated using a + weighted sum. The weights are obtained from the (softmax normalized) + Cross-covariance matrix :math:`(Q^T \cdot K \in d_h \times d_h)` + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. Defaults to 8. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. 
+ attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + proj_drop (float): The drop out rate for linear output weights. + Defaults to 0. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + init_cfg=None): + super(XCA, self).__init__(init_cfg=init_cfg) + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + # (qkv, B, num_heads, channels per head, N) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) + + # Paper section 3.2 l2-Normalization and temperature scaling + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # (B, num_heads, C', N) -> (B, N, num_heads, C') -> (B, N C) + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class XCABlock(BaseModule): + """Transformer block using XCA. + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. + mlp_ratio (float): The hidden dimension ratio for FFNs. + Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to False. + drop (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initial value for layer scale. + Defaults to 1. + bn_norm_cfg (dict): Config dict for batchnorm in LPI and + ConvPatchEmbed. Defaults to ``dict(type='BN')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN', eps=1e-6)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. + """ + + def __init__(self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + layer_scale_init_value: float = 1., + bn_norm_cfg=dict(type='BN'), + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=None): + super(XCABlock, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, dim) + self.attn = XCA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm3 = build_norm_layer(norm_cfg, dim) + self.local_mp = LPI( + in_features=dim, + norm_cfg=bn_norm_cfg, + act_cfg=act_cfg, + ) + + self.norm2 = build_norm_layer(norm_cfg, dim) + self.ffn = FFN( + embed_dims=dim, + feedforward_channels=int(dim * mlp_ratio), + act_cfg=act_cfg, + ffn_drop=drop, + ) + + self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim)) + self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim)) + self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim)) + + def forward(self, x, H: int, W: int): + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) + # NOTE official code has 3 then 2, so keeping it the same to be + # consistent with loaded weights See + # https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501 + x = x + self.drop_path( + self.gamma3 * self.local_mp(self.norm3(x), H, W)) + x = x + self.drop_path( + self.gamma2 * self.ffn(self.norm2(x), identity=0)) + return x + + +@MODELS.register_module() +class XCiT(BaseBackbone): + """XCiT backbone. + + A PyTorch implementation of XCiT backbone introduced by: + `XCiT: Cross-Covariance Image Transformers + `_ + + Args: + img_size (int, tuple): Input image size. Defaults to 224. + patch_size (int): Patch size. Defaults to 16. + in_channels (int): Number of input channels. Defaults to 3. + embed_dims (int): Embedding dimension. Defaults to 768. + depth (int): depth of vision transformer. Defaults to 12. + cls_attn_layers (int): Depth of Class attention layers. + Defaults to 2. + num_heads (int): Number of attention heads. Defaults to 12. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Defaults to 4. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + use_pos_embed (bool): Whether to use positional encoding. + Defaults to True. + layer_scale_init_value (float): The initial value for layer scale. + Defaults to 1. + tokens_norm (bool): Whether to normalize all tokens or just the + cls_token in the CA. Defaults to False. + out_indices (Sequence[int]): Output from which layers. + Defaults to (-1, ). + frozen_stages (int): Layers to be frozen (all param fixed), and 0 + means to freeze the stem stage. Defaults to -1, which means + not freeze any parameters. + bn_norm_cfg (dict): Config dict for the batch norm layers in LPI and + ConvPatchEmbed. Defaults to ``dict(type='BN')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN', eps=1e-6)``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='GELU')``. + init_cfg (dict | list[dict], optional): Initialization config dict. 
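For orientation, the ``out_type`` options map onto the token tensor as in the short sketch below (a standalone illustration with assumed sizes, mirroring ``_format_output`` further down):

import torch

B, C, Hp, Wp = 2, 384, 14, 14
x = torch.rand(B, Hp * Wp + 1, C)  # tokens after the class-attention stage, cls token first

cls_token = x[:, 0]                                           # 'cls_token'   -> (B, C)
featmap = x[:, 1:].reshape(B, Hp, Wp, C).permute(0, 3, 1, 2)  # 'featmap'     -> (B, C, Hp, Wp)
avg_featmap = x[:, 1:].mean(dim=1)                            # 'avg_featmap' -> (B, C)
# 'raw' simply returns x unchanged.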
+ """ + + def __init__(self, + img_size: Union[int, tuple] = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dims: int = 768, + depth: int = 12, + cls_attn_layers: int = 2, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + use_pos_embed: bool = True, + layer_scale_init_value: float = 1., + tokens_norm: bool = False, + out_type: str = 'cls_token', + out_indices: Sequence[int] = (-1, ), + final_norm: bool = True, + frozen_stages: int = -1, + bn_norm_cfg=dict(type='BN'), + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + init_cfg=dict(type='TruncNormal', layer='Linear')): + super(XCiT, self).__init__(init_cfg=init_cfg) + + img_size = to_2tuple(img_size) + if (img_size[0] % patch_size != 0) or (img_size[1] % patch_size != 0): + raise ValueError(f'`patch_size` ({patch_size}) should divide ' + f'the image shape ({img_size}) evenly.') + + self.embed_dims = embed_dims + + assert out_type in ('raw', 'featmap', 'avg_featmap', 'cls_token') + self.out_type = out_type + + self.patch_embed = ConvPatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims, + norm_cfg=bn_norm_cfg, + act_cfg=act_cfg, + ) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.use_pos_embed = use_pos_embed + if use_pos_embed: + self.pos_embed = PositionalEncodingFourier(dim=embed_dims) + self.pos_drop = nn.Dropout(p=drop_rate) + + self.xca_layers = nn.ModuleList() + self.ca_layers = nn.ModuleList() + self.num_layers = depth + cls_attn_layers + + for _ in range(depth): + self.xca_layers.append( + XCABlock( + dim=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + bn_norm_cfg=bn_norm_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + layer_scale_init_value=layer_scale_init_value, + )) + + for _ in range(cls_attn_layers): + self.ca_layers.append( + ClassAttentionBlock( + dim=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + tokens_norm=tokens_norm, + )) + + if final_norm: + self.norm = build_norm_layer(norm_cfg, embed_dims) + + # Transform out_indices + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + out_indices = list(out_indices) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, \ + f'Invalid out_indices {index}.' 
+ self.out_indices = out_indices + + if frozen_stages > self.num_layers + 1: + raise ValueError('frozen_stages must be less than ' + f'{self.num_layers} but get {frozen_stages}') + self.frozen_stages = frozen_stages + + def init_weights(self): + super().init_weights() + + if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained': + return + + trunc_normal_(self.cls_token, std=.02) + + def _freeze_stages(self): + if self.frozen_stages < 0: + return + + # freeze position embedding + if self.use_pos_embed: + self.pos_embed.eval() + for param in self.pos_embed.parameters(): + param.requires_grad = False + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # set dropout to eval model + self.pos_drop.eval() + # freeze cls_token, only use in self.Clslayers + if self.frozen_stages > len(self.xca_layers): + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages): + if i <= len(self.xca_layers): + m = self.xca_layers[i - 1] + else: + m = self.ca_layers[i - len(self.xca_layers) - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze the last layer norm if all_stages are frozen + if self.frozen_stages == len(self.xca_layers) + len(self.ca_layers): + self.norm.eval() + for param in self.norm.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + B = x.shape[0] + # x is (B, N, C). (Hp, Hw) is the patch resolution + x, (Hp, Wp) = self.patch_embed(x) + + if self.use_pos_embed: + # (B, C, Hp, Wp) -> (B, C, N) -> (B, N, C) + pos_encoding = self.pos_embed(B, Hp, Wp) + x = x + pos_encoding.reshape(B, -1, x.size(1)).permute(0, 2, 1) + x = self.pos_drop(x) + + for i, layer in enumerate(self.xca_layers): + x = layer(x, Hp, Wp) + if i in self.out_indices: + outs.append(self._format_output(x, (Hp, Wp), False)) + + x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) + + for i, layer in enumerate(self.ca_layers): + x = layer(x) + if i == len(self.ca_layers) - 1: + x = self.norm(x) + if i + len(self.xca_layers) in self.out_indices: + outs.append(self._format_output(x, (Hp, Wp), True)) + + return tuple(outs) + + def _format_output(self, x, hw, with_cls_token: bool): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + if not with_cls_token: + raise ValueError( + 'Cannot output cls_token since there is no cls_token.') + return x[:, 0] + + patch_token = x[:, 1:] if with_cls_token else x + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() diff --git a/mmpretrain/models/builder.py b/mmpretrain/models/builder.py new file mode 100644 index 0000000..2ea4e25 --- /dev/null +++ b/mmpretrain/models/builder.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
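The builder helpers defined in this file are thin aliases over the single ``MODELS`` registry, so they all resolve to the same build call. A minimal usage sketch with the XCiT backbone registered above (the hyper-parameter values are illustrative assumptions, not fixed by this patch):

import torch
from mmpretrain.models.builder import build_backbone

backbone = build_backbone(dict(type='XCiT', embed_dims=384, depth=12, num_heads=8))
feats = backbone(torch.rand(1, 3, 224, 224))  # tuple of output tensors
# build_backbone(cfg) is equivalent to MODELS.build(cfg).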
+from mmpretrain.registry import MODELS + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +CLASSIFIERS = MODELS +RETRIEVER = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_classifier(cfg): + """Build classifier.""" + return CLASSIFIERS.build(cfg) + + +def build_retriever(cfg): + """Build retriever.""" + return RETRIEVER.build(cfg) diff --git a/mmpretrain/models/classifiers/__init__.py b/mmpretrain/models/classifiers/__init__.py new file mode 100644 index 0000000..5fa276f --- /dev/null +++ b/mmpretrain/models/classifiers/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseClassifier +from .hugging_face import HuggingFaceClassifier +from .image import ImageClassifier +from .timm import TimmClassifier + +__all__ = [ + 'BaseClassifier', 'ImageClassifier', 'TimmClassifier', + 'HuggingFaceClassifier' +] diff --git a/mmpretrain/models/classifiers/base.py b/mmpretrain/models/classifiers/base.py new file mode 100644 index 0000000..a65fc21 --- /dev/null +++ b/mmpretrain/models/classifiers/base.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Sequence + +import torch +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement + + +class BaseClassifier(BaseModel, metaclass=ABCMeta): + """Base class for classifiers. + + Args: + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None, it will use "BaseDataPreprocessor" as type, see + :class:`mmengine.model.BaseDataPreprocessor` for more details. + Defaults to None. + + Attributes: + init_cfg (dict): Initialization config dict. + data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An + extra data pre-processing module, which processes data from + dataloader to the format accepted by :meth:`forward`. + """ + + def __init__(self, + init_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None): + super(BaseClassifier, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + @property + def with_neck(self) -> bool: + """Whether the classifier has a neck.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_head(self) -> bool: + """Whether the classifier has a head.""" + return hasattr(self, 'head') and self.head is not None + + @abstractmethod + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`BaseDataElement`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. 
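The three modes can be exercised with any concrete subclass; a small sketch using ``ImageClassifier`` with an illustrative ResNet-18 config (all values are assumptions for the example, not mandated by this patch):

import torch
from mmpretrain.models import ImageClassifier
from mmpretrain.structures import DataSample

model = ImageClassifier(
    backbone=dict(type='ResNet', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(type='LinearClsHead', num_classes=10, in_channels=512))

inputs = torch.rand(2, 3, 224, 224)
data_samples = [DataSample().set_gt_label(i) for i in range(2)]

logits = model(inputs, mode='tensor')                # raw scores, shape (2, 10)
losses = model(inputs, data_samples, mode='loss')    # dict such as {'loss': tensor(...)}
preds = model(inputs, data_samples, mode='predict')  # list of DataSample with pred_label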
+ + Args: + inputs (torch.Tensor): The input tensor with shape (N, C, ...) + in general. + data_samples (List[BaseDataElement], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of + :obj:`mmengine.BaseDataElement`. + - If ``mode="loss"``, return a dict of tensor. + """ + pass + + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The sub-classes are recommended to implement this method to extract + features from backbone and neck. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + """ + raise NotImplementedError + + def extract_feats(self, multi_inputs: Sequence[torch.Tensor], + **kwargs) -> list: + """Extract features from a sequence of input tensor. + + Args: + multi_inputs (Sequence[torch.Tensor]): A sequence of input + tensor. It can be used in augmented inference. + **kwargs: Other keyword arguments accepted by :meth:`extract_feat`. + + Returns: + list: Features of every input tensor. + """ + assert isinstance(multi_inputs, Sequence), \ + '`extract_feats` is used for a sequence of inputs tensor. If you '\ + 'want to extract on single inputs tensor, use `extract_feat`.' + return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs] diff --git a/mmpretrain/models/classifiers/hugging_face.py b/mmpretrain/models/classifiers/hugging_face.py new file mode 100644 index 0000000..26a8fda --- /dev/null +++ b/mmpretrain/models/classifiers/hugging_face.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All right reserved. +import re +from collections import OrderedDict +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from mmpretrain.utils import require +from .base import BaseClassifier + + +@MODELS.register_module() +class HuggingFaceClassifier(BaseClassifier): + """Image classifiers for HuggingFace model. + + This class accepts all positional and keyword arguments of the API + ``from_pretrained`` (when ``pretrained=True``) and ``from_config`` (when + ``pretrained=False``) of `transformers.AutoModelForImageClassification`_ + and use it to create a model from hugging-face. + + It can load checkpoints of hugging-face directly, and the saved checkpoints + also can be directly load by hugging-face. + + Please confirm that you have installed ``transfromers`` if you want to use it. + + .. _transformers.AutoModelForImageClassification: + https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification + + Args: + model_name (str): The name of the model to use in hugging-face. + pretrained (bool): Whether to load pretrained checkpoint from + hugging-face. Defaults to False. + *args: Other positional arguments of the method + `from_pretrained` or `from_config`. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in :mod:`mmpretrain.model.utils.augment`. 
+ + Defaults to None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + **kwargs: Other keyword arguments of the method + `from_pretrained` or `from_config`. + + Examples: + >>> import torch + >>> from mmpretrain.models import build_classifier + >>> cfg = dict(type='HuggingFaceClassifier', model_name='microsoft/resnet-50', pretrained=True) + >>> model = build_classifier(cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> out = model(inputs) + >>> print(out.shape) + torch.Size([1, 1000]) + """ # noqa: E501 + + @require('transformers') + def __init__(self, + model_name, + pretrained=False, + *model_args, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + train_cfg: Optional[dict] = None, + with_cp: bool = False, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs): + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. + data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + from transformers import AutoConfig, AutoModelForImageClassification + if pretrained: + self.model = AutoModelForImageClassification.from_pretrained( + model_name, *model_args, **kwargs) + else: + config = AutoConfig.from_pretrained(model_name, *model_args, + **kwargs) + self.model = AutoModelForImageClassification.from_config(config) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + self.with_cp = with_cp + if self.with_cp: + self.model.gradient_checkpointing_enable() + + self._register_state_dict_hook(self._remove_state_dict_prefix) + self._register_load_state_dict_pre_hook(self._add_state_dict_prefix) + + def forward(self, inputs, data_samples=None, mode='tensor'): + if mode == 'tensor': + return self.model(inputs).logits + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + raise NotImplementedError( + "The HuggingFaceClassifier doesn't support extract feature yet.") + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs): + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments of the loss module. 
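The way targets are packed in ``_get_loss`` below can be seen in isolation (a sketch with made-up values; ``gt_label`` entries are 1-D tensors, which is why ``cat`` rather than ``stack`` is used for hard labels):

import torch
from mmpretrain.structures import DataSample

hard = [DataSample().set_gt_label(1), DataSample().set_gt_label(3)]
target = torch.cat([s.gt_label for s in hard])        # tensor([1, 3])

# After batch augmentations such as Mixup, labels become per-class scores.
soft = [DataSample().set_gt_score(torch.tensor([0.7, 0.3])) for _ in range(2)]
target = torch.stack([s.gt_score for s in soft])      # shape (2, 2)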
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs).logits + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None): + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + + Returns: + List[DataSample]: The prediction results. + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs).logits + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + + return data_samples + + @staticmethod + def _remove_state_dict_prefix(self, state_dict, prefix, local_metadata): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_key = re.sub(f'^{prefix}model.', prefix, k) + new_state_dict[new_key] = v + return new_state_dict + + @staticmethod + def _add_state_dict_prefix(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + new_prefix = prefix + 'model.' + for k in list(state_dict.keys()): + new_key = re.sub(f'^{prefix}', new_prefix, k) + state_dict[new_key] = state_dict[k] + del state_dict[k] diff --git a/mmpretrain/models/classifiers/image.py b/mmpretrain/models/classifiers/image.py new file mode 100644 index 0000000..6d0edd7 --- /dev/null +++ b/mmpretrain/models/classifiers/image.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseClassifier + + +@MODELS.register_module() +class ImageClassifier(BaseClassifier): + """Image classifiers for supervised classification task. + + Args: + backbone (dict): The backbone module. See + :mod:`mmpretrain.models.backbones`. + neck (dict, optional): The neck module to process features from + backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. 
+ head (dict, optional): The head module to do prediction and calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, almost all methods cannot be + used except :meth:`extract_feat`. Defaults to None. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in + :mod:`mmpretrain.model.utils.augment`. + - probs (List[float], optional): The probability of every batch + augmentation methods. If None, choose evenly. Defaults to None. + + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + pretrained: Optional[str] = None, + train_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if pretrained is not None: + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + data_preprocessor = data_preprocessor or {} + + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'ClsDataPreprocessor') + data_preprocessor.setdefault('batch_augments', train_cfg) + data_preprocessor = MODELS.build(data_preprocessor) + elif not isinstance(data_preprocessor, nn.Module): + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + super(ImageClassifier, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(backbone, nn.Module): + backbone = MODELS.build(backbone) + if neck is not None and not isinstance(neck, nn.Module): + neck = MODELS.build(neck) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + + self.backbone = backbone + self.neck = neck + self.head = head + + # If the model needs to load pretrain weights from a third party, + # the key can be modified with this hook + if hasattr(self.backbone, '_checkpoint_filter'): + self._register_load_state_dict_pre_hook( + self.backbone._checkpoint_filter) + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor(s) without any + post-processing, same as a common PyTorch Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. 
+ - If ``mode="predict"``, return a list of + :obj:`mmpretrain.structures.DataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + feats = self.extract_feat(inputs) + return self.head(feats) if self.with_head else feats + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs, stage='neck'): + """Extract features from the input tensor with shape (N, C, ...). + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + stage (str): Which stage to output the feature. Choose from: + + - "backbone": The output of backbone network. Returns a tuple + including multiple stages features. + - "neck": The output of neck module. Returns a tuple including + multiple stages features. + - "pre_logits": The feature before the final classification + linear layer. Usually returns a tensor. + + Defaults to "neck". + + Returns: + tuple | Tensor: The output of specified stage. + The output depends on detailed implementation. In general, the + output of backbone and neck is a tuple and the output of + pre_logits is a tensor. + + Examples: + 1. Backbone output + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 256, 14, 14]) + torch.Size([1, 512, 7, 7]) + + 2. Neck output + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64]) + torch.Size([1, 128]) + torch.Size([1, 256]) + torch.Size([1, 512]) + + 3. Pre-logits output (without the final linear classifier head) + + >>> import torch + >>> from mmengine import Config + >>> from mmpretrain.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model + >>> model = build_classifier(cfg) + >>> + >>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits') + >>> print(out.shape) # The hidden dims in head is 3072 + torch.Size([1, 3072]) + """ # noqa: E501 + assert stage in ['backbone', 'neck', 'pre_logits'], \ + (f'Invalid output stage "{stage}", please choose from "backbone", ' + '"neck" and "pre_logits"') + + x = self.backbone(inputs) + + if stage == 'backbone': + return x + + if self.with_neck: + x = self.neck(x) + if stage == 'neck': + return x + + assert self.with_head and hasattr(self.head, 'pre_logits'), \ + "No head or the head doesn't implement `pre_logits` method." + return self.head.pre_logits(x) + + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. 
+ + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + feats = self.extract_feat(inputs) + return self.head.loss(feats, data_samples) + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **kwargs) -> List[DataSample]: + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + """ + feats = self.extract_feat(inputs) + return self.head.predict(feats, data_samples, **kwargs) + + def get_layer_depth(self, param_name: str): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + + Returns: + Tuple[int, int]: The layer-wise depth and the max depth. + """ + if hasattr(self.backbone, 'get_layer_depth'): + return self.backbone.get_layer_depth(param_name, 'backbone.') + else: + raise NotImplementedError( + f"The backbone {type(self.backbone)} doesn't " + 'support `get_layer_depth` by now.') diff --git a/mmpretrain/models/classifiers/timm.py b/mmpretrain/models/classifiers/timm.py new file mode 100644 index 0000000..d777b2e --- /dev/null +++ b/mmpretrain/models/classifiers/timm.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All right reserved. +import re +from collections import OrderedDict +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from mmpretrain.utils import require +from .base import BaseClassifier + + +@MODELS.register_module() +class TimmClassifier(BaseClassifier): + """Image classifiers for pytorch-image-models (timm) model. + + This class accepts all positional and keyword arguments of the function + `timm.models.create_model `_ and use + it to create a model from pytorch-image-models. + + It can load checkpoints of timm directly, and the saved checkpoints also + can be directly load by timm. + + Please confirm that you have installed ``timm`` if you want to use it. + + Args: + *args: All positional arguments of the function + `timm.models.create_model`. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in :mod:`mmpretrain.model.utils.augment`. + + Defaults to None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + **kwargs: Other keyword arguments of the function + `timm.models.create_model`. 
+ + Examples: + >>> import torch + >>> from mmpretrain.models import build_classifier + >>> cfg = dict(type='TimmClassifier', model_name='resnet50', pretrained=True) + >>> model = build_classifier(cfg) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> out = model(inputs) + >>> print(out.shape) + torch.Size([1, 1000]) + """ # noqa: E501 + + @require('timm') + def __init__(self, + *args, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + train_cfg: Optional[dict] = None, + with_cp: bool = False, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + **kwargs): + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. + data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + from timm.models import create_model + self.model = create_model(*args, **kwargs) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + self.with_cp = with_cp + if self.with_cp: + self.model.set_grad_checkpointing() + + self._register_state_dict_hook(self._remove_state_dict_prefix) + self._register_load_state_dict_pre_hook(self._add_state_dict_prefix) + + def forward(self, inputs, data_samples=None, mode='tensor'): + if mode == 'tensor': + return self.model(inputs) + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + if hasattr(self.model, 'forward_features'): + return self.model.forward_features(inputs) + else: + raise NotImplementedError( + f"The model {type(self.model)} doesn't support extract " + "feature because it don't have `forward_features` method.") + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs): + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments of the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self.model(inputs) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module(cls_score, target, **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None): + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. 
+ + Returns: + List[DataSample]: The prediction results. + """ + # The part can be traced by torch.fx + cls_score = self(inputs) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples=None): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + + return data_samples + + @staticmethod + def _remove_state_dict_prefix(self, state_dict, prefix, local_metadata): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_key = re.sub(f'^{prefix}model.', prefix, k) + new_state_dict[new_key] = v + return new_state_dict + + @staticmethod + def _add_state_dict_prefix(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + new_prefix = prefix + 'model.' + for k in list(state_dict.keys()): + new_key = re.sub(f'^{prefix}', new_prefix, k) + state_dict[new_key] = state_dict[k] + del state_dict[k] diff --git a/mmpretrain/models/heads/__init__.py b/mmpretrain/models/heads/__init__.py new file mode 100644 index 0000000..4364fb5 --- /dev/null +++ b/mmpretrain/models/heads/__init__.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .beitv1_head import BEiTV1Head +from .beitv2_head import BEiTV2Head +from .cae_head import CAEHead +from .cls_head import ClsHead +from .conformer_head import ConformerHead +from .contrastive_head import ContrastiveHead +from .deit_head import DeiTClsHead +from .efficientformer_head import EfficientFormerClsHead +from .grounding_head import GroundingHead +from .itc_head import ITCHead +from .itm_head import ITMHead +from .itpn_clip_head import iTPNClipHead +from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead +from .levit_head import LeViTClsHead +from .linear_head import LinearClsHead +from .mae_head import MAEPretrainHead +from .margin_head import ArcFaceClsHead +from .mim_head import MIMHead +from .mixmim_head import MixMIMPretrainHead +from .mocov3_head import MoCoV3Head +from .multi_label_cls_head import MultiLabelClsHead +from .multi_label_csra_head import CSRAClsHead +from .multi_label_linear_head import MultiLabelLinearClsHead +from .multi_task_head import MultiTaskHead +from .seq_gen_head import SeqGenerationHead +from .simmim_head import SimMIMHead +from .spark_head import SparKPretrainHead +from .stacked_head import StackedLinearClsHead +from .swav_head import SwAVHead +from .vig_head import VigClsHead +from .vision_transformer_head import VisionTransformerClsHead +from .vqa_head import VQAGenerationHead + +__all__ = [ + 'ClsHead', + 'LinearClsHead', + 'StackedLinearClsHead', + 'MultiLabelClsHead', + 'MultiLabelLinearClsHead', + 'VisionTransformerClsHead', + 'DeiTClsHead', + 'ConformerHead', + 'EfficientFormerClsHead', + 'ArcFaceClsHead', + 'CSRAClsHead', + 'MultiTaskHead', + 'LeViTClsHead', + 'VigClsHead', + 'BEiTV1Head', + 'BEiTV2Head', + 'CAEHead', + 'ContrastiveHead', + 'LatentCrossCorrelationHead', + 'LatentPredictHead', + 'MAEPretrainHead', + 
'MixMIMPretrainHead', + 'SwAVHead', + 'MoCoV3Head', + 'MIMHead', + 'SimMIMHead', + 'SeqGenerationHead', + 'VQAGenerationHead', + 'ITCHead', + 'ITMHead', + 'GroundingHead', + 'iTPNClipHead', + 'SparKPretrainHead', +] diff --git a/mmpretrain/models/heads/beitv1_head.py b/mmpretrain/models/heads/beitv1_head.py new file mode 100644 index 0000000..df422ea --- /dev/null +++ b/mmpretrain/models/heads/beitv1_head.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class BEiTV1Head(BaseModule): + """Head for BEiT v1 Pre-training. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss_module = MODELS.build(loss) + + def loss(self, feats: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraing. + """ + mask = mask.flatten(1).to(torch.bool) + target = torch.argmax(target, dim=1).flatten(1) + target = target[mask] + + # remove cls_token + feats = feats[:, 1:] + logits = self.cls_head(feats[mask]) + + loss = self.loss_module(logits, target) + return loss diff --git a/mmpretrain/models/heads/beitv2_head.py b/mmpretrain/models/heads/beitv2_head.py new file mode 100644 index 0000000..cf677a2 --- /dev/null +++ b/mmpretrain/models/heads/beitv2_head.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class BEiTV2Head(BaseModule): + """Head for BEiT v2 Pre-training. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss_module = MODELS.build(loss) + + def loss(self, feats: torch.Tensor, feats_cls_pt: torch.Tensor, + target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + feats_cls_pt (torch.Tensor) : Features from class late layers for + pretraining. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraing. 
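Shape-wise, the masked visual-token objective used by these BEiT heads reduces to a cross-entropy over the tokenizer vocabulary at the masked positions; a standalone sketch with illustrative sizes (a plain ``Linear`` stands in for ``cls_head``):

import torch
import torch.nn.functional as F

B, L, C, vocab = 2, 196, 768, 8192
feats = torch.rand(B, L, C)                      # patch-token features from the backbone
target = torch.randint(0, vocab, (B, L))         # visual-token ids from the target generator
mask = torch.rand(B, L) > 0.6                    # True at masked positions

logits = torch.nn.Linear(C, vocab)(feats[mask])  # (num_masked, vocab)
loss = F.cross_entropy(logits, target[mask])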
+ """ + mask = mask.flatten(1).to(torch.bool) + target = target[mask] + + # shared cls head + logits = self.cls_head(feats[mask]) + logits_cls_pt = self.cls_head(feats_cls_pt[mask]) + + loss_1 = self.loss_module(logits, target) + loss_2 = self.loss_module(logits_cls_pt, target) + return loss_1, loss_2 diff --git a/mmpretrain/models/heads/cae_head.py b/mmpretrain/models/heads/cae_head.py new file mode 100644 index 0000000..18a07f0 --- /dev/null +++ b/mmpretrain/models/heads/cae_head.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CAEHead(BaseModule): + """Head for CAE Pre-training. + + Compute the align loss and the main loss. In addition, this head also + generates the prediction target generated by dalle. + + Args: + loss (dict): The config of loss. + tokenizer_path (str): The path of the tokenizer. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + + @torch.no_grad() + def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor: + """Generate the reconstruction target. + + Args: + logits_target (torch.Tensor): The logits generated by DALL-E.s + + Returns: + torch.Tensor: The logits target. + """ + target = torch.argmax(logits_target, dim=1) + return target.flatten(1) + + def loss(self, logits: torch.Tensor, logits_target: torch.Tensor, + latent_pred: torch.Tensor, latent_target: torch.Tensor, + mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate loss. + + Args: + logits (torch.Tensor): Logits generated by decoder. + logits_target (img_target): Target generated by dalle for decoder + prediction. + latent_pred (torch.Tensor): Latent prediction by regressor. + latent_target (torch.Tensor): Target for latent prediction, + generated by teacher. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The tuple of loss. + - ``loss_main`` (torch.Tensor): Cross entropy loss. + - ``loss_align`` (torch.Tensor): MSE loss. + """ + + target = self._generate_target(logits_target) # target features + target = target[mask].detach() + + # loss main for decoder, loss align for regressor + loss_main, loss_align = self.loss_module(logits, target, latent_pred, + latent_target) + + return (loss_main, loss_align) diff --git a/mmpretrain/models/heads/cls_head.py b/mmpretrain/models/heads/cls_head.py new file mode 100644 index 0000000..4ac4c51 --- /dev/null +++ b/mmpretrain/models/heads/cls_head.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.evaluation.metrics import Accuracy +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class ClsHead(BaseModule): + """Classification head. + + Args: + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. + cal_acc (bool): Whether to calculate accuracy during training. 
+ If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + topk: Union[int, Tuple[int]] = (1, ), + cal_acc: bool = False, + init_cfg: Optional[dict] = None): + super(ClsHead, self).__init__(init_cfg=init_cfg) + + self.topk = topk + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + self.cal_acc = cal_acc + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ClsHead``, we just obtain the feature + of the last stage. + """ + # The ClsHead doesn't have other module, just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The ClsHead doesn't have the final classification head, + # just return the unpacked inputs. + return pre_logits + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + # compute accuracy + if self.cal_acc: + assert target.ndim == 1, 'If you enable batch augmentation ' \ + 'like mixup during training, `cal_acc` is pointless.' + acc = Accuracy.calculate(cls_score, target, topk=self.topk) + losses.update( + {f'accuracy_top-{k}': a + for k, a in zip(self.topk, acc)}) + + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: Optional[List[Optional[DataSample]]] = None + ) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample | None], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. 
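The post-processing performed by ``_get_predictions`` below amounts to a softmax plus argmax written back onto the data samples; a minimal sketch with made-up logits:

import torch
import torch.nn.functional as F
from mmpretrain.structures import DataSample

cls_score = torch.tensor([[0.2, 1.5, -0.3]])
pred_scores = F.softmax(cls_score, dim=1)
pred_labels = pred_scores.argmax(dim=1, keepdim=True)

sample = DataSample().set_pred_score(pred_scores[0]).set_pred_label(pred_labels[0])
# sample.pred_label is now tensor([1])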
+ + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples diff --git a/mmpretrain/models/heads/conformer_head.py b/mmpretrain/models/heads/conformer_head.py new file mode 100644 index 0000000..eade90d --- /dev/null +++ b/mmpretrain/models/heads/conformer_head.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.evaluation.metrics import Accuracy +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +@MODELS.register_module() +class ConformerHead(ClsHead): + """Linear classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (Sequence[int]): Number of channels in the input + feature map. + init_cfg (dict | optional): The extra init config of layers. + Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + + def __init__( + self, + num_classes: int, + in_channels: Sequence[int], # [conv_dim, trans_dim] + init_cfg: dict = dict(type='TruncNormal', layer='Linear', std=.02), + **kwargs): + super(ConformerHead, self).__init__(init_cfg=init_cfg, **kwargs) + + self.in_channels = in_channels + self.num_classes = num_classes + self.init_cfg = init_cfg + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes) + self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes) + + def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ConformerHead``, we just obtain the + feature of the last stage. + """ + # The ConformerHead doesn't have other module, + # just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]: + """The forward process.""" + x = self.pre_logits(feats) + # There are two outputs in the Conformer model + assert len(x) == 2 + + conv_cls_score = self.conv_cls_head(x[0]) + tran_cls_score = self.trans_cls_head(x[1]) + + return conv_cls_score, tran_cls_score + + def predict(self, + feats: Tuple[List[torch.Tensor]], + data_samples: List[DataSample] = None) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. 
The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + conv_cls_score, tran_cls_score = self(feats) + cls_score = conv_cls_score + tran_cls_score + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_loss(self, cls_score: Tuple[torch.Tensor], + data_samples: List[DataSample], **kwargs) -> dict: + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. + target = torch.stack([i.gt_score for i in data_samples]) + else: + target = torch.cat([i.gt_label for i in data_samples]) + + # compute loss + losses = dict() + loss = sum([ + self.loss_module( + score, target, avg_factor=score.size(0), **kwargs) + for score in cls_score + ]) + losses['loss'] = loss + + # compute accuracy + if self.cal_acc: + assert target.ndim == 1, 'If you enable batch augmentation ' \ + 'like mixup during training, `cal_acc` is pointless.' + acc = Accuracy.calculate( + cls_score[0] + cls_score[1], target, topk=self.topk) + losses.update( + {f'accuracy_top-{k}': a + for k, a in zip(self.topk, acc)}) + + return losses diff --git a/mmpretrain/models/heads/contrastive_head.py b/mmpretrain/models/heads/contrastive_head.py new file mode 100644 index 0000000..6d1474a --- /dev/null +++ b/mmpretrain/models/heads/contrastive_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class ContrastiveHead(BaseModule): + """Head for contrastive learning. + + The contrastive loss is implemented in this head and is used in SimCLR, + MoCo, DenseCL, etc. + + Args: + loss (dict): Config dict for module of loss functions. + temperature (float): The temperature hyper-parameter that + controls the concentration level of the distribution. + Defaults to 0.1. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + temperature: float = 0.1, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + self.temperature = temperature + + def loss(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor: + """Forward function to compute contrastive loss. + + Args: + pos (torch.Tensor): Nx1 positive similarity. + neg (torch.Tensor): Nxk negative similarity. + + Returns: + torch.Tensor: The contrastive loss. + """ + N = pos.size(0) + logits = torch.cat((pos, neg), dim=1) + logits /= self.temperature + labels = torch.zeros((N, ), dtype=torch.long).to(pos.device) + + loss = self.loss_module(logits, labels) + return loss diff --git a/mmpretrain/models/heads/deit_head.py b/mmpretrain/models/heads/deit_head.py new file mode 100644 index 0000000..a96f6e1 --- /dev/null +++ b/mmpretrain/models/heads/deit_head.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
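For reference, the ``ContrastiveHead`` above is the standard InfoNCE construction: the positive similarity occupies column 0 of the logit matrix, so a zero label vector fed to cross entropy recovers the contrastive objective. A minimal self-contained sketch, using plain ``F.cross_entropy`` in place of the configurable ``loss_module`` (an assumption made only to keep the snippet standalone):

.. code:: python

    import torch
    import torch.nn.functional as F

    N, k, temperature = 4, 7, 0.1
    pos = torch.rand(N, 1)  # Nx1 positive similarities
    neg = torch.rand(N, k)  # Nxk negative similarities

    # Column 0 holds the positive pair, so the "class" of every row is 0.
    logits = torch.cat((pos, neg), dim=1) / temperature
    labels = torch.zeros(N, dtype=torch.long)

    loss = F.cross_entropy(logits, labels)  # InfoNCE loss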
+import warnings +from typing import List, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .vision_transformer_head import VisionTransformerClsHead + + +@MODELS.register_module() +class DeiTClsHead(VisionTransformerClsHead): + """Distilled Vision Transformer classifier head. + + Comparing with the :class:`VisionTransformerClsHead`, this head adds an + extra linear layer to handle the dist token. The final classification score + is the average of both linear transformation results of ``cls_token`` and + ``dist_token``. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int, optional): Number of the dimensions for hidden layer. + Defaults to None, which means no extra hidden layer. + act_cfg (dict): The activation config. Only available during + pre-training. Defaults to ``dict(type='Tanh')``. + init_cfg (dict): The extra initialization configs. Defaults to + ``dict(type='Constant', layer='Linear', val=0)``. + """ + + def _init_layers(self): + """"Init extra hidden linear layer to handle dist token if exists.""" + super(DeiTClsHead, self)._init_layers() + if self.hidden_dim is None: + head_dist = nn.Linear(self.in_channels, self.num_classes) + else: + head_dist = nn.Linear(self.hidden_dim, self.num_classes) + self.layers.add_module('head_dist', head_dist) + + def pre_logits(self, + feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]: + """The process before the final classification head. + + The input ``feats`` is a tuple of list of tensor, and each tensor is + the feature of a backbone stage. In ``DeiTClsHead``, we obtain the + feature of the last stage and forward in hidden layer if exists. + """ + feat = feats[-1] # Obtain feature of the last scale. + # For backward-compatibility with the previous ViT output + if len(feat) == 3: + _, cls_token, dist_token = feat + else: + cls_token, dist_token = feat + if self.hidden_dim is None: + return cls_token, dist_token + else: + cls_token = self.layers.act(self.layers.pre_logits(cls_token)) + dist_token = self.layers.act(self.layers.pre_logits(dist_token)) + return cls_token, dist_token + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The forward process.""" + if self.training: + warnings.warn('MMPretrain cannot train the ' + 'distilled version DeiT.') + cls_token, dist_token = self.pre_logits(feats) + # The final classification head. + cls_score = (self.layers.head(cls_token) + + self.layers.head_dist(dist_token)) / 2 + return cls_score diff --git a/mmpretrain/models/heads/efficientformer_head.py b/mmpretrain/models/heads/efficientformer_head.py new file mode 100644 index 0000000..09aa05b --- /dev/null +++ b/mmpretrain/models/heads/efficientformer_head.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +@MODELS.register_module() +class EfficientFormerClsHead(ClsHead): + """EfficientFormer classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + distillation (bool): Whether use a additional distilled head. + Defaults to True. + init_cfg (dict): The extra initialization configs. 
Defaults to + ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + + def __init__(self, + num_classes, + in_channels, + distillation=True, + init_cfg=dict(type='Normal', layer='Linear', std=0.01), + *args, + **kwargs): + super(EfficientFormerClsHead, self).__init__( + init_cfg=init_cfg, *args, **kwargs) + self.in_channels = in_channels + self.num_classes = num_classes + self.dist = distillation + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.head = nn.Linear(self.in_channels, self.num_classes) + if self.dist: + self.dist_head = nn.Linear(self.in_channels, self.num_classes) + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.head(pre_logits) + + if self.dist: + cls_score = (cls_score + self.dist_head(pre_logits)) / 2 + return cls_score + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In :obj`EfficientFormerClsHead`, we just + obtain the feature of the last stage. + """ + # The EfficientFormerClsHead doesn't have other module, just return + # after unpacking. + return feats[-1] + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + if self.dist: + raise NotImplementedError( + "MMPretrain doesn't support to train" + ' the distilled version EfficientFormer.') + else: + return super().loss(feats, data_samples, **kwargs) diff --git a/mmpretrain/models/heads/grounding_head.py b/mmpretrain/models/heads/grounding_head.py new file mode 100644 index 0000000..a47512e --- /dev/null +++ b/mmpretrain/models/heads/grounding_head.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.models.utils.box_utils import (box_cxcywh_to_xyxy, + generalized_box_iou) +from mmpretrain.registry import MODELS, TOKENIZER + + +@MODELS.register_module() +class GroundingHead(BaseModule): + """bbox Coordination generation head for multi-modal pre-trained task, + adapted by BLIP. Normally used for visual grounding. + + Args: + loss: dict, + decoder: dict, + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. 
+ """ + + def __init__( + self, + decoder: dict = None, + tokenizer: dict = None, + box_l1_loss_coeff=4.0, + box_giou_loss_coeff=2.0, + init_cfg: Optional[dict] = None, + ) -> None: + super(GroundingHead, self).__init__(init_cfg=init_cfg) + ''' init the decoder from med_config''' + self.decoder = None + if decoder: + self.decoder = MODELS.build(decoder) + self.loss_fn = torch.nn.CrossEntropyLoss( + reduction='none', ignore_index=-100) + + self.box_l1_loss_coeff = box_l1_loss_coeff + self.box_giou_loss_coeff = box_giou_loss_coeff + + if isinstance(tokenizer, dict): + self.tokenizer = TOKENIZER.build(tokenizer) + else: + self.tokenizer = tokenizer + + self.image_res = 640 + prefix_ids = torch.tensor( + self.tokenizer.convert_tokens_to_ids(['[unused339]'])) + target_ids = torch.tensor( + self.tokenizer.convert_tokens_to_ids( + [f'[unused{340+_}]' for _ in range(self.image_res + 1)])) + self.register_buffer('prefix_ids', prefix_ids) + self.register_buffer('target_ids', target_ids) + + bbox_prob_mask = torch.zeros(len(self.tokenizer)) + bbox_prob_mask[self.target_ids[0]:self.target_ids[-1] + 1] = 1 + bbox_prob_mask = (1.0 - bbox_prob_mask) * -10000.0 + self.register_buffer('bbox_prob_mask', bbox_prob_mask) + self.bin_start_idx = self.target_ids[0] + + def forward(self, text_embedding, text_embedding_mask, + encoder_hidden_states, encoder_attention_mask): + + # localize prompt token, text embedding + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + loc_prompt = self.prompt.weight.T + loc_prompt = torch.repeat_interleave(loc_prompt, + merge_att_mask.shape[0], + 0).unsqueeze(1) + + loc_prompt_mask = torch.ones(loc_prompt.shape[:-1]).long().to( + loc_prompt.device) + + decoder_out = self.decoder( + inputs_embeds=loc_prompt, + attention_mask=loc_prompt_mask, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + output_hidden_states=True, + labels=None, + ) + decoder_hs = decoder_out.hidden_states[-1][:, 0, :] + box_pred = self.box_head(decoder_hs) + return decoder_out, decoder_hs, box_pred + + def loss(self, + text_embedding, + text_embedding_mask, + encoder_hidden_states, + encoder_attention_mask, + decoder_targets, + return_scores=False): + """Calculate losses from the extracted features. + + Args: + feats (dict): The features extracted from the backbone. + data_samples (List[BaseDataElement]): The annotation data of + every samples. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + answer_targets = (decoder_targets * + self.image_res).long() + self.bin_start_idx + prefix_ids = torch.repeat_interleave(self.prefix_ids, + merge_att_mask.shape[0], + 0).unsqueeze(-1) + prefix_ids = torch.cat([prefix_ids, answer_targets], dim=1) + + answer_output = self.decoder( + prefix_ids, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + labels=None, + return_dict=True, + ) + prob_mask = self.bbox_prob_mask.view(1, 1, + self.bbox_prob_mask.shape[-1]) + prediction_scores = answer_output.logits + prob_mask + + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = prefix_ids[:, 1:].contiguous() + vocab_size = len(self.tokenizer) + loss_seq_init = self.loss_fn( + shifted_prediction_scores.view(-1, vocab_size), labels.view(-1)) + + with torch.no_grad(): + pred_box = (torch.argmax( + prediction_scores[:, :-1, :].contiguous(), dim=-1) - + self.bin_start_idx) / self.image_res + weight_bbox = F.l1_loss( + pred_box, decoder_targets, reduction='none').clamp( + 0, 5) * self.box_l1_loss_coeff + weight_giou = (1 - torch.diag( + generalized_box_iou( + box_cxcywh_to_xyxy(pred_box), + box_cxcywh_to_xyxy(decoder_targets))) + ) * self.box_giou_loss_coeff + bs = text_embedding.shape[0] + loss_seq = loss_seq_init[:].view(bs, -1, 4) + loss_seq = loss_seq * weight_bbox + loss_seq = loss_seq * weight_giou.unsqueeze(1) + + loss_seq = loss_seq.mean() + + losses = { + 'loss_seq': loss_seq, + 'loss_seq_init': loss_seq_init.mean(), + 'loss': loss_seq, + 'box_l1': weight_bbox.mean(-1).mean().detach(), + 'box_giou': weight_giou.mean().detach() + } + + return losses + + def predict( + self, + text_embedding, + text_embedding_mask, + encoder_hidden_states, + encoder_attention_mask, + ): + """Generates the bbox coordinates at inference time.""" + + merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding], + 1) + merge_att_mask = torch.cat( + [encoder_attention_mask, text_embedding_mask], 1) + + prefix_ids = torch.repeat_interleave(self.prefix_ids, + merge_att_mask.shape[0], + 0).unsqueeze(-1) + + for _ in range(4): + decoder_output = self.decoder( + prefix_ids, + encoder_hidden_states=merged_encode_hs, + encoder_attention_mask=merge_att_mask, + labels=None, + return_dict=True, + ) + prob_mask = self.bbox_prob_mask.view(1, 1, + self.bbox_prob_mask.shape[-1]) + prediction_scores = decoder_output.logits + prob_mask + + prefix_ids = torch.cat([ + prefix_ids, + torch.argmax(prediction_scores[:, -1, :], dim=-1).unsqueeze(1) + ], + dim=1) + + pred_box = self.process_bbox(prefix_ids[:, 1:]) # xywh 0-1 to xyxy 0-1 + + return pred_box + + @torch.no_grad() + def process_bbox(self, bbox): + bbox = bbox - self.bin_start_idx + bbox = torch.true_divide(bbox, self.image_res) + bbox = box_cxcywh_to_xyxy(bbox) + bbox = torch.clip(bbox, 0, 1) + assert torch.all(bbox <= 1) + return bbox diff --git a/mmpretrain/models/heads/itc_head.py b/mmpretrain/models/heads/itc_head.py new file mode 100644 index 0000000..006d52c --- /dev/null +++ b/mmpretrain/models/heads/itc_head.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
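``GroundingHead`` above regresses boxes by emitting coordinate tokens: each normalized coordinate in ``[0, 1]`` is quantized into ``image_res`` bins and offset by the id of the first coordinate token, and ``process_bbox`` inverts that mapping. A rough sketch of the round trip (``bin_start_idx`` is an illustrative value, not one read from the tokenizer):

.. code:: python

    import torch

    image_res = 640       # number of coordinate bins, as in GroundingHead
    bin_start_idx = 340   # illustrative id of the first coordinate token

    # Normalized (cx, cy, w, h) targets in [0, 1].
    boxes = torch.tensor([[0.25, 0.50, 0.30, 0.40]])

    # Encode: quantize each coordinate into a vocabulary token id.
    tokens = (boxes * image_res).long() + bin_start_idx

    # Decode, as in process_bbox before the cxcywh -> xyxy conversion.
    decoded = (tokens - bin_start_idx).float() / image_res
    # `decoded` equals `boxes` up to 1/image_res quantization error.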
+from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.dist import all_gather +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class ITCHead(BaseModule): + """Image-text matching head for multi-modal pre-trained task. Adapted by + BLIP, ALBEF. Normally used for retrieval task. + + Args: + embed_dim (int): Embed channel size for queue. + queue_size (int): Queue size for image and text. Defaults to 57600. + temperature (float): Temperature to calculate the similarity. + Defaults to 0.07. + use_distill (bool): Whether to use distill to calculate loss. + Defaults to True. + alpha (float): Weight for momentum similarity. Defaults to 0.4. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + embed_dim: int, + queue_size: int = 57600, + temperature: float = 0.07, + use_distill: bool = True, + alpha: float = 0.4, + init_cfg: Optional[dict] = None): + super(ITCHead, self).__init__(init_cfg=init_cfg) + self.temp = nn.Parameter(temperature * torch.ones([])) + self.use_distill = use_distill + if self.use_distill: + # create the queue + self.register_buffer('image_queue', + torch.randn(embed_dim, queue_size)) + self.register_buffer('text_queue', + torch.randn(embed_dim, queue_size)) + self.register_buffer('idx_queue', torch.full((1, queue_size), + -100)) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + self.image_queue = F.normalize(self.image_queue, dim=0) + self.text_queue = F.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + # This value will be warmup by `WarmupParamHook` + self.alpha = alpha + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + return feats[-1] + + def loss(self, feats: Tuple[torch.Tensor], data_samples, **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[ClsDataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + # The part can be traced by torch.fx + img_feats, text_feats, img_feats_m, text_feats_m = self(feats) + + img_feats_all = torch.cat( + [img_feats_m.t(), + self.image_queue.clone().detach()], dim=1) + text_feats_all = torch.cat( + [text_feats_m.t(), + self.text_queue.clone().detach()], dim=1) + + # The part can not be traced by torch.fx + losses = self._get_loss(img_feats, text_feats, img_feats_m, + text_feats_m, img_feats_all, text_feats_all, + data_samples, **kwargs) + return losses + + def _get_loss(self, img_feats, text_feats, img_feats_m, text_feats_m, + img_feats_all, text_feats_all, data_samples, **kwargs): + """Unpack data samples and compute loss.""" + + idx = torch.tensor([ds.image_id + for ds in data_samples]).to(img_feats.device) + idx = idx.view(-1, 1) + idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) + + with torch.no_grad(): + if self.use_distill: + sim_i2t_m = img_feats_m @ text_feats_all / self.temp + sim_t2i_m = text_feats_m @ img_feats_all / self.temp + + sim_i2t_targets = ( + self.alpha * F.softmax(sim_i2t_m, dim=1) + + (1 - self.alpha) * sim_targets) + sim_t2i_targets = ( + self.alpha * F.softmax(sim_t2i_m, dim=1) + + (1 - self.alpha) * sim_targets) + + sim_i2t = img_feats @ text_feats_all / self.temp + sim_t2i = text_feats @ img_feats_all / self.temp + + if self.use_distill: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean() + else: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean() + + # compute loss + losses = dict() + + losses['itc_loss'] = (loss_i2t + loss_t2i) / 2 + self._dequeue_and_enqueue(img_feats_m, text_feats_m, idx) + return losses + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): + # gather keys before updating queue + image_feats = torch.cat(all_gather(image_feat)) + text_feats = torch.cat(all_gather(text_feat)) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + + if idxs is not None: + idxs = torch.cat(all_gather(idxs)) + self.idx_queue[:, ptr:ptr + batch_size] = idxs.T + + ptr = (ptr + batch_size) % self.queue_size # move pointer + self.queue_ptr[0] = ptr diff --git a/mmpretrain/models/heads/itm_head.py b/mmpretrain/models/heads/itm_head.py new file mode 100644 index 0000000..c7b42f3 --- /dev/null +++ b/mmpretrain/models/heads/itm_head.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
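The feature queue updated by ``ITCHead._dequeue_and_enqueue`` above is a plain ring buffer over the columns of ``image_queue``/``text_queue``. A single-process sketch without ``all_gather``, keeping the same simplifying assumption that the queue size is divisible by the batch size:

.. code:: python

    import torch

    embed_dim, queue_size, batch_size = 8, 32, 4
    image_queue = torch.randn(embed_dim, queue_size)
    queue_ptr = 0

    def dequeue_and_enqueue(feats: torch.Tensor) -> None:
        """Write a (batch, embed_dim) batch into the column ring buffer."""
        global queue_ptr
        bs = feats.shape[0]
        assert queue_size % bs == 0  # same simplification as ITCHead
        image_queue[:, queue_ptr:queue_ptr + bs] = feats.T
        queue_ptr = (queue_ptr + bs) % queue_size  # wrap the pointer

    for _ in range(10):
        dequeue_and_enqueue(torch.randn(batch_size, embed_dim))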
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+
+from mmpretrain.evaluation import Accuracy
+from mmpretrain.registry import MODELS
+
+
+class Pooler(nn.Module):
+
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+@MODELS.register_module()
+class ITMHead(BaseModule):
+    """Image-text matching head for multi-modal pre-trained task. Adapted
+    from BLIP and FLAVA.
+
+    Args:
+        hidden_size (int): Hidden channel size of the input features.
+        with_pooler (bool): Whether a pooler is added. Defaults to True.
+        loss (dict): Config of the matching loss. Defaults to
+            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
+        cal_acc (bool): Whether to calculate accuracy during training.
+            If you use batch augmentations like Mixup and CutMix during
+            training, it is pointless to calculate accuracy.
+            Defaults to False.
+        init_cfg (dict, optional): the config to control the initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 hidden_size: int,
+                 with_pooler: bool = True,
+                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
+                 cal_acc: bool = False,
+                 init_cfg: Optional[dict] = None):
+        super(ITMHead, self).__init__(init_cfg=init_cfg)
+        self.hidden_size = hidden_size
+
+        if with_pooler:
+            self.pooler = Pooler(hidden_size=self.hidden_size)
+        else:
+            self.pooler = nn.Identity()
+        self.fc = nn.Linear(self.hidden_size, 2)
+
+        self.loss_module = MODELS.build(loss)
+        self.cal_acc = cal_acc
+
+    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The forward process."""
+        pre_logits = self.pooler(feats[-1])
+        itm_logits = self.fc(pre_logits)
+        return itm_logits
+
+    def loss(self, feats: Tuple[torch.Tensor], data_samples, **kwargs) -> dict:
+        """Calculate losses from the classification score.
+
+        Args:
+            feats (tuple[Tensor]): The features extracted from the backbone.
+                Multiple stage inputs are acceptable but only the last stage
+                will be used to classify. The shape of every item should be
+                ``(num_samples, num_classes)``.
+            data_samples (List[ClsDataSample]): The annotation data of
+                every sample.
+            **kwargs: Other keyword arguments to forward the loss module.
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + # The part can be traced by torch.fx + itm_logits = self(feats) + + # deal with query + if itm_logits.ndim == 3: + itm_logits = itm_logits.mean(dim=1) + + # The part can not be traced by torch.fx + losses = self._get_loss(itm_logits, data_samples, **kwargs) + return losses + + def _get_loss(self, itm_logits: torch.Tensor, data_samples, **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + # use `itm_label` in here temporarily + target = torch.tensor([i.is_matched + for i in data_samples]).to(itm_logits.device) + + # compute loss + losses = dict() + + loss = self.loss_module( + itm_logits, target.long(), avg_factor=itm_logits.size(0), **kwargs) + losses['itm_loss'] = loss + + # compute accuracy + if self.cal_acc: + # topk is meaningless for matching task + acc = Accuracy.calculate(itm_logits, target) + # acc is warpped with two lists of topk and thrs + # which are unnecessary here + losses.update({'itm_accuracy': acc[0][0]}) + + return losses diff --git a/mmpretrain/models/heads/itpn_clip_head.py b/mmpretrain/models/heads/itpn_clip_head.py new file mode 100644 index 0000000..52c49b8 --- /dev/null +++ b/mmpretrain/models/heads/itpn_clip_head.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.device import get_device +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class iTPNClipHead(BaseModule): + """Head for iTPN Pre-training using Clip. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss_module = MODELS.build(loss) + + def loss(self, feats: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraing. + """ + mask = mask.to(get_device(), non_blocking=True) + mask = mask.flatten(1).to(torch.bool) + target = target[mask] + + # remove cls_token + # feats = feats[:, 1:] + logits = self.cls_head(feats[mask]) + + loss = self.loss_module(logits, target) + return loss diff --git a/mmpretrain/models/heads/latent_heads.py b/mmpretrain/models/heads/latent_heads.py new file mode 100644 index 0000000..a9662b5 --- /dev/null +++ b/mmpretrain/models/heads/latent_heads.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_reduce, get_world_size +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class LatentPredictHead(BaseModule): + """Head for latent feature prediction. + + This head builds a predictor, which can be any registered neck component. 
+ For example, BYOL and SimSiam call this head and build NonLinearNeck. + It also implements similarity loss between two forward features. + + Args: + loss (dict): Config dict for the loss. + predictor (dict): Config dict for the predictor. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + loss: dict, + predictor: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + self.predictor = MODELS.build(predictor) + + def loss(self, input: torch.Tensor, + target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward head. + + Args: + input (torch.Tensor): NxC input features. + target (torch.Tensor): NxC target features. + + Returns: + torch.Tensor: The latent predict loss. + """ + pred = self.predictor([input])[0] + target = target.detach() + + loss = self.loss_module(pred, target) + + return loss + + +@MODELS.register_module() +class LatentCrossCorrelationHead(BaseModule): + """Head for latent feature cross correlation. + + Part of the code is borrowed from `script + `_. + + Args: + in_channels (int): Number of input channels. + loss (dict): Config dict for module of loss functions. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.world_size = get_world_size() + self.bn = nn.BatchNorm1d(in_channels, affine=False) + self.loss_module = MODELS.build(loss) + + def loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Forward head. + + Args: + input (torch.Tensor): NxC input features. + target (torch.Tensor): NxC target features. + + Returns: + torch.Tensor: The cross correlation loss. + """ + # cross-correlation matrix + cross_correlation_matrix = self.bn(input).T @ self.bn(target) + cross_correlation_matrix.div_(input.size(0) * self.world_size) + + all_reduce(cross_correlation_matrix) + + loss = self.loss_module(cross_correlation_matrix) + return loss diff --git a/mmpretrain/models/heads/levit_head.py b/mmpretrain/models/heads/levit_head.py new file mode 100644 index 0000000..a74d7ec --- /dev/null +++ b/mmpretrain/models/heads/levit_head.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
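``LatentCrossCorrelationHead`` above only builds the batch-normalized cross-correlation matrix and delegates the rest to the configured loss module. Assuming the usual Barlow Twins criterion (diagonal pulled towards 1, off-diagonal towards 0), the full computation can be sketched as below; the off-diagonal weight ``lambd`` is an illustrative value, not one taken from this codebase:

.. code:: python

    import torch
    import torch.nn as nn

    N, C = 16, 32
    z1, z2 = torch.randn(N, C), torch.randn(N, C)

    bn = nn.BatchNorm1d(C, affine=False)
    cc = bn(z1).T @ bn(z2) / N  # CxC cross-correlation matrix

    lambd = 5e-3                # assumed off-diagonal weight
    diag = torch.diagonal(cc)
    on_diag = (diag - 1).pow(2).sum()               # push diagonal to 1
    off_diag = cc.pow(2).sum() - diag.pow(2).sum()  # push the rest to 0
    loss = on_diag + lambd * off_diag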
+import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.models.heads import ClsHead +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +class BatchNormLinear(BaseModule): + + def __init__(self, in_channels, out_channels, norm_cfg=dict(type='BN1d')): + super(BatchNormLinear, self).__init__() + self.bn = build_norm_layer(norm_cfg, in_channels) + self.linear = nn.Linear(in_channels, out_channels) + + @torch.no_grad() + def fuse(self): + w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5 + b = self.bn.bias - self.bn.running_mean * \ + self.bn.weight / (self.bn.running_var + self.bn.eps) ** 0.5 + w = self.linear.weight * w[None, :] + b = (self.linear.weight @ b[:, None]).view(-1) + self.linear.bias + + self.linear.weight.data.copy_(w) + self.linear.bias.data.copy_(b) + return self.linear + + def forward(self, x): + x = self.bn(x) + x = self.linear(x) + return x + + +def fuse_parameters(module): + for child_name, child in module.named_children(): + if hasattr(child, 'fuse'): + setattr(module, child_name, child.fuse()) + else: + fuse_parameters(child) + + +@MODELS.register_module() +class LeViTClsHead(ClsHead): + + def __init__(self, + num_classes=1000, + distillation=True, + in_channels=None, + deploy=False, + **kwargs): + super(LeViTClsHead, self).__init__(**kwargs) + self.num_classes = num_classes + self.distillation = distillation + self.deploy = deploy + self.head = BatchNormLinear(in_channels, num_classes) + if distillation: + self.head_dist = BatchNormLinear(in_channels, num_classes) + + if self.deploy: + self.switch_to_deploy(self) + + def switch_to_deploy(self): + if self.deploy: + return + fuse_parameters(self) + self.deploy = True + + def forward(self, x): + x = self.pre_logits(x) + if self.distillation: + x = self.head(x), self.head_dist(x) # 2 16 384 -> 2 1000 + if not self.training: + x = (x[0] + x[1]) / 2 + else: + raise NotImplementedError("MMPretrain doesn't support " + 'training in distillation mode.') + else: + x = self.head(x) + return x diff --git a/mmpretrain/models/heads/linear_head.py b/mmpretrain/models/heads/linear_head.py new file mode 100644 index 0000000..90b4c2b --- /dev/null +++ b/mmpretrain/models/heads/linear_head.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class LinearClsHead(ClsHead): + """Linear classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. + cal_acc (bool): Whether to calculate accuracy during training. + If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``. 
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + init_cfg: Optional[dict] = dict( + type='Normal', layer='Linear', std=0.01), + **kwargs): + super(LinearClsHead, self).__init__(init_cfg=init_cfg, **kwargs) + + self.in_channels = in_channels + self.num_classes = num_classes + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self.fc = nn.Linear(self.in_channels, self.num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``LinearClsHead``, we just obtain the + feature of the last stage. + """ + # The LinearClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/mae_head.py b/mmpretrain/models/heads/mae_head.py new file mode 100644 index 0000000..b76eced --- /dev/null +++ b/mmpretrain/models/heads/mae_head.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MAEPretrainHead(BaseModule): + """Head for MAE Pre-training. + + Args: + loss (dict): Config of loss. + norm_pix_loss (bool): Whether or not normalize target. + Defaults to False. + patch_size (int): Patch size. Defaults to 16. + in_channels (int): Number of input channels. Defaults to 3. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = False, + patch_size: int = 16, + in_channels: int = 3) -> None: + super().__init__() + self.norm_pix = norm_pix + self.patch_size = patch_size + self.in_channels = in_channels + self.loss_module = MODELS.build(loss) + + def patchify(self, imgs: torch.Tensor) -> torch.Tensor: + r"""Split images into non-overlapped patches. + + Args: + imgs (torch.Tensor): A batch of images. The shape should + be :math:`(B, C, H, W)`. + + Returns: + torch.Tensor: Patchified images. The shape is + :math:`(B, L, \text{patch_size}^2 \times C)`. + """ + p = self.patch_size + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels)) + return x + + def unpatchify(self, x: torch.Tensor) -> torch.Tensor: + r"""Combine non-overlapped patches into images. + + Args: + x (torch.Tensor): The shape is + :math:`(B, L, \text{patch_size}^2 \times C)`. + + Returns: + torch.Tensor: The shape is :math:`(B, C, H, W)`. + """ + p = self.patch_size + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p)) + return imgs + + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + """Construct the reconstruction target. + + In addition to splitting images into tokens, this module will also + normalize the image according to ``norm_pix``. 
+ + Args: + target (torch.Tensor): Image with the shape of B x C x H x W + + Returns: + torch.Tensor: Tokenized images with the shape of B x L x C + """ + target = self.patchify(target) + if self.norm_pix: + # normalize the target image + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + return target + + def loss(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + target = self.construct_target(target) + loss = self.loss_module(pred, target, mask) + + return loss diff --git a/mmpretrain/models/heads/margin_head.py b/mmpretrain/models/heads/margin_head.py new file mode 100644 index 0000000..3a88bf8 --- /dev/null +++ b/mmpretrain/models/heads/margin_head.py @@ -0,0 +1,300 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.fileio import list_from_file +from mmengine.runner import autocast +from mmengine.utils import is_seq_of + +from mmpretrain.models.losses import convert_to_one_hot +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .cls_head import ClsHead + + +class NormProduct(nn.Linear): + """An enhanced linear layer with k clustering centers to calculate product + between normalized input and linear weight. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample + k (int): The number of clustering centers. Defaults to 1. + bias (bool): Whether there is bias. If set to ``False``, the + layer will not learn an additive bias. Defaults to ``True``. + feature_norm (bool): Whether to normalize the input feature. + Defaults to ``True``. + weight_norm (bool):Whether to normalize the weight. + Defaults to ``True``. + """ + + def __init__(self, + in_features: int, + out_features: int, + k=1, + bias: bool = False, + feature_norm: bool = True, + weight_norm: bool = True): + + super().__init__(in_features, out_features * k, bias=bias) + self.weight_norm = weight_norm + self.feature_norm = feature_norm + self.out_features = out_features + self.k = k + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.feature_norm: + input = F.normalize(input) + if self.weight_norm: + weight = F.normalize(self.weight) + else: + weight = self.weight + cosine_all = F.linear(input, weight, self.bias) + + if self.k == 1: + return cosine_all + else: + cosine_all = cosine_all.view(-1, self.out_features, self.k) + cosine, _ = torch.max(cosine_all, dim=2) + return cosine + + +@MODELS.register_module() +class ArcFaceClsHead(ClsHead): + """ArcFace classifier head. + + A PyTorch implementation of paper `ArcFace: Additive Angular Margin Loss + for Deep Face Recognition `_ and + `Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web + Faces `_ + + Example: + To use ArcFace in config files. + + 1. use vanilla ArcFace + + .. code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + 2. use SubCenterArcFace with 3 sub-centers + + .. 
code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + num_subcenters=3, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + 3. use SubCenterArcFace With CountPowerAdaptiveMargins + + .. code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + num_subcenters=3, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + custom_hooks = [dict(type='SetAdaptiveMarginsHook')] + + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + num_subcenters (int): Number of subcenters. Defaults to 1. + scale (float): Scale factor of output logit. Defaults to 64.0. + margins (float): The penalty margin. Could be the fllowing formats: + + - float: The margin, would be same for all the categories. + - Sequence[float]: The category-based margins list. + - str: A '.txt' file path which contains a list. Each line + represents the margin of a category, and the number in the + i-th row indicates the margin of the i-th class. + + Defaults to 0.5. + easy_margin (bool): Avoid theta + m >= PI. Defaults to False. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + num_subcenters: int = 1, + scale: float = 64., + margins: Optional[Union[float, Sequence[float], str]] = 0.50, + easy_margin: bool = False, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg: Optional[dict] = None): + + super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg) + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + assert num_subcenters >= 1 and num_classes >= 0 + self.in_channels = in_channels + self.num_classes = num_classes + self.num_subcenters = num_subcenters + self.scale = scale + self.easy_margin = easy_margin + + self.norm_product = NormProduct(in_channels, num_classes, + num_subcenters) + + if isinstance(margins, float): + margins = [margins] * num_classes + elif isinstance(margins, str) and margins.endswith('.txt'): + margins = [float(item) for item in list_from_file(margins)] + else: + assert is_seq_of(list(margins), (float, int)), ( + 'the attribute `margins` in ``ArcFaceClsHead`` should be a ' + ' float, a Sequence of float, or a ".txt" file path.') + + assert len(margins) == num_classes, \ + 'The length of margins must be equal with num_classes.' + + self.register_buffer( + 'margins', torch.tensor(margins).float(), persistent=False) + # To make `phi` monotonic decreasing, refers to + # https://github.com/deepinsight/insightface/issues/108 + sinm_m = torch.sin(math.pi - self.margins) * self.margins + threshold = torch.cos(math.pi - self.margins) + self.register_buffer('sinm_m', sinm_m, persistent=False) + self.register_buffer('threshold', threshold, persistent=False) + + def set_margins(self, margins: Union[Sequence[float], float]) -> None: + """set margins of arcface head. + + Args: + margins (Union[Sequence[float], float]): The marigins. 
+ """ + if isinstance(margins, float): + margins = [margins] * self.num_classes + assert is_seq_of( + list(margins), float) and (len(margins) == self.num_classes), ( + f'margins must be Sequence[Union(float, int)], get {margins}') + + self.margins = torch.tensor( + margins, device=self.margins.device, dtype=torch.float32) + self.sinm_m = torch.sin(self.margins) * self.margins + self.threshold = -torch.cos(self.margins) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ArcFaceHead``, we just obtain the + feature of the last stage. + """ + # The ArcFaceHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def _get_logit_with_margin(self, pre_logits, target): + """add arc margin to the cosine in target index. + + The target must be in index format. + """ + assert target.dim() == 1 or ( + target.dim() == 2 and target.shape[1] == 1), \ + 'The target must be in index format.' + cosine = self.norm_product(pre_logits) + phi = torch.cos(torch.acos(cosine) + self.margins) + + if self.easy_margin: + # when cosine>0, choose phi + # when cosine<=0, choose cosine + phi = torch.where(cosine > 0, phi, cosine) + else: + # when cos>th, choose phi + # when cos<=th, choose cosine-mm + phi = torch.where(cosine > self.threshold, phi, + cosine - self.sinm_m) + + target = convert_to_one_hot(target, self.num_classes) + output = target * phi + (1 - target) * cosine + return output + + def forward(self, + feats: Tuple[torch.Tensor], + target: Optional[torch.Tensor] = None) -> torch.Tensor: + """The forward process.""" + # Disable AMP + with autocast(enabled=False): + pre_logits = self.pre_logits(feats) + + if target is None: + # when eval, logit is the cosine between W and pre_logits; + # cos(theta_yj) = (x/||x||) * (W/||W||) + logit = self.norm_product(pre_logits) + else: + # when training, add a margin to the pre_logits where target is + # True, then logit is the cosine between W and new pre_logits + logit = self._get_logit_with_margin(pre_logits, target) + + return self.scale * logit + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # Unpack data samples and pack targets + label_target = torch.cat([i.gt_label for i in data_samples]) + if 'gt_score' in data_samples[0]: + # Batch augmentation may convert labels to one-hot format scores. 
+ target = torch.stack([i.gt_score for i in data_samples]) + else: + target = label_target + + # the index format target would be used + cls_score = self(feats, label_target) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses diff --git a/mmpretrain/models/heads/mim_head.py b/mmpretrain/models/heads/mim_head.py new file mode 100644 index 0000000..bda90c8 --- /dev/null +++ b/mmpretrain/models/heads/mim_head.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MIMHead(BaseModule): + """Pre-training head for Masked Image Modeling. + + Args: + loss (dict): Config dict for module of loss functions. + """ + + def __init__(self, loss: dict) -> None: + super().__init__() + self.loss_module = MODELS.build(loss) + + def loss(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward head. + + Args: + pred (torch.Tensor): Predictions with shape B x L x C. + target (torch.Tensor): Targets with shape B x L x C. + mask (torch.Tensor): Mask with shape B x L. + + Returns: + torch.Tensor: The loss tensor. + """ + loss = self.loss_module(pred, target, mask) + return loss diff --git a/mmpretrain/models/heads/mixmim_head.py b/mmpretrain/models/heads/mixmim_head.py new file mode 100644 index 0000000..a709630 --- /dev/null +++ b/mmpretrain/models/heads/mixmim_head.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmpretrain.registry import MODELS +from .mae_head import MAEPretrainHead + + +@MODELS.register_module() +class MixMIMPretrainHead(MAEPretrainHead): + """Head for MixMIM Pre-training. + + Args: + loss (dict): Config of loss. + norm_pix_loss (bool): Whether or not normalize target. + Defaults to False. + patch_size (int): Patch size. Defaults to 16. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = False, + patch_size: int = 16) -> None: + super().__init__(loss=loss, norm_pix=norm_pix, patch_size=patch_size) + + def loss(self, x_rec: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + target = self.construct_target(target) + + B, L, C = x_rec.shape + + # unmix tokens + x1_rec = x_rec[:B // 2] + x2_rec = x_rec[B // 2:] + + unmix_x_rec = x1_rec * mask + x2_rec.flip(0) * (1 - mask) + + loss_rec = self.loss_module(unmix_x_rec, target) + + return loss_rec diff --git a/mmpretrain/models/heads/mocov3_head.py b/mmpretrain/models/heads/mocov3_head.py new file mode 100644 index 0000000..c2bec2a --- /dev/null +++ b/mmpretrain/models/heads/mocov3_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.dist import all_gather, get_rank +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MoCoV3Head(BaseModule): + """Head for MoCo v3 Pre-training. + + This head builds a predictor, which can be any registered neck component. + It also implements latent contrastive loss between two forward features. 
+ Part of the code is modified from: + ``_. + + Args: + predictor (dict): Config dict for module of predictor. + loss (dict): Config dict for module of loss functions. + temperature (float): The temperature hyper-parameter that + controls the concentration level of the distribution. + Defaults to 1.0. + """ + + def __init__(self, + predictor: dict, + loss: dict, + temperature: float = 1.0) -> None: + super().__init__() + self.predictor = MODELS.build(predictor) + self.loss_module = MODELS.build(loss) + self.temperature = temperature + + def loss(self, base_out: torch.Tensor, + momentum_out: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + base_out (torch.Tensor): NxC features from base_encoder. + momentum_out (torch.Tensor): NxC features from momentum_encoder. + + Returns: + torch.Tensor: The loss tensor. + """ + # predictor computation + pred = self.predictor([base_out])[0] + + # normalize + pred = nn.functional.normalize(pred, dim=1) + target = nn.functional.normalize(momentum_out, dim=1) + + # get negative samples + target = torch.cat(all_gather(target), dim=0) + + # Einstein sum is more intuitive + logits = torch.einsum('nc,mc->nm', [pred, target]) / self.temperature + + # generate labels + batch_size = logits.shape[0] + labels = (torch.arange(batch_size, dtype=torch.long) + + batch_size * get_rank()).to(logits.device) + + loss = self.loss_module(logits, labels) + return loss diff --git a/mmpretrain/models/heads/multi_label_cls_head.py b/mmpretrain/models/heads/multi_label_cls_head.py new file mode 100644 index 0000000..ca36bfe --- /dev/null +++ b/mmpretrain/models/heads/multi_label_cls_head.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample, label_to_onehot + + +@MODELS.register_module() +class MultiLabelClsHead(BaseModule): + """Classification head for multilabel task. + + Args: + loss (dict): Config of classification loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True). + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + init_cfg (dict, optional): The extra init config of layers. + Defaults to None. + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + + def __init__(self, + loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True), + thr: Optional[float] = None, + topk: Optional[int] = None, + init_cfg: Optional[dict] = None): + super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg) + + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + if thr is None and topk is None: + thr = 0.5 + + self.thr = thr + self.topk = topk + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``MultiLabelClsHead``, we just obtain + the feature of the last stage. + """ + # The MultiLabelClsHead doesn't have other module, just return after + # unpacking. 
+ return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The MultiLabelClsHead doesn't have the final classification head, + # just return the unpacked inputs. + return pre_logits + + def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[DataSample], **kwargs): + """Unpack data samples and compute loss.""" + num_classes = cls_score.size()[-1] + # Unpack data samples and pack targets + if 'gt_score' in data_samples[0]: + target = torch.stack([i.gt_score.float() for i in data_samples]) + else: + target = torch.stack([ + label_to_onehot(i.gt_label, num_classes) for i in data_samples + ]).float() + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses + + def predict(self, + feats: Tuple[torch.Tensor], + data_samples: List[DataSample] = None) -> List[DataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[DataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[DataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score: torch.Tensor, + data_samples: List[DataSample]): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = torch.sigmoid(cls_score) + + if data_samples is None: + data_samples = [DataSample() for _ in range(cls_score.size(0))] + + for data_sample, score in zip(data_samples, pred_scores): + if self.thr is not None: + # a label is predicted positive if larger than thr + label = torch.where(score >= self.thr)[0] + else: + # top-k labels will be predicted positive for any example + _, label = score.topk(self.topk) + data_sample.set_pred_score(score).set_pred_label(label) + + return data_samples diff --git a/mmpretrain/models/heads/multi_label_csra_head.py b/mmpretrain/models/heads/multi_label_csra_head.py new file mode 100644 index 0000000..95a3a0e --- /dev/null +++ b/mmpretrain/models/heads/multi_label_csra_head.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
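To make the ``thr``/``topk`` decision rule of ``MultiLabelClsHead._get_predictions`` above concrete: scores are sigmoids of the logits, and a label is predicted positive either because it clears ``thr`` (when a threshold is set) or because it ranks among the ``topk`` highest scores. A small sketch with made-up logits showing both rules side by side:

.. code:: python

    import torch

    cls_score = torch.tensor([[2.0, -1.0, 0.3, -3.0]])  # logits, 4 classes
    scores = torch.sigmoid(cls_score)[0]

    thr, topk = 0.5, 2
    by_thr = torch.where(scores >= thr)[0]  # labels above the threshold
    by_topk = scores.topk(topk).indices     # the k highest-scoring labels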
+# Modified from https://github.com/Kevinz-code/CSRA +from typing import Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from .multi_label_cls_head import MultiLabelClsHead + + +@MODELS.register_module() +class CSRAClsHead(MultiLabelClsHead): + """Class-specific residual attention classifier head. + + Please refer to the `Residual Attention: A Simple but Effective Method for + Multi-Label Recognition (ICCV 2021) `_ + for details. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + num_heads (int): Number of residual at tensor heads. + loss (dict): Config of classification loss. + lam (float): Lambda that combines global average and max pooling + scores. + init_cfg (dict, optional): The extra init config of layers. + Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``. + """ + temperature_settings = { # softmax temperature settings + 1: [1], + 2: [1, 99], + 4: [1, 2, 4, 99], + 6: [1, 2, 3, 4, 5, 99], + 8: [1, 2, 3, 4, 5, 6, 7, 99] + } + + def __init__(self, + num_classes: int, + in_channels: int, + num_heads: int, + lam: float, + init_cfg=dict(type='Normal', layer='Linear', std=0.01), + **kwargs): + assert num_heads in self.temperature_settings.keys( + ), 'The num of heads is not in temperature setting.' + assert lam > 0, 'Lambda should be between 0 and 1.' + super(CSRAClsHead, self).__init__(init_cfg=init_cfg, **kwargs) + self.temp_list = self.temperature_settings[num_heads] + self.csra_heads = ModuleList([ + CSRAModule(num_classes, in_channels, self.temp_list[i], lam) + for i in range(num_heads) + ]) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``CSRAClsHead``, we just obtain the + feature of the last stage. + """ + # The CSRAClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + logit = sum([head(pre_logits) for head in self.csra_heads]) + return logit + + +class CSRAModule(BaseModule): + """Basic module of CSRA with different temperature. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + T (int): Temperature setting. + lam (float): Lambda that combines global average and max pooling + scores. + init_cfg (dict | optional): The extra init config of layers. + Defaults to use dict(type='Normal', layer='Linear', std=0.01). 
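+ Examples:
+     A minimal shape sketch of the attention computation below (the batch
+     size and feature map size are arbitrary)::
+ 
+         >>> import torch
+         >>> module = CSRAModule(num_classes=5, in_channels=16, T=99, lam=0.1)
+         >>> module(torch.rand(2, 16, 7, 7)).shape
+         torch.Size([2, 5])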
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + T: int, + lam: float, + init_cfg=None): + + super(CSRAModule, self).__init__(init_cfg=init_cfg) + self.T = T # temperature + self.lam = lam # Lambda + self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False) + self.softmax = nn.Softmax(dim=2) + + def forward(self, x): + score = self.head(x) / torch.norm( + self.head.weight, dim=1, keepdim=True).transpose(0, 1) + score = score.flatten(2) + base_logit = torch.mean(score, dim=2) + + if self.T == 99: # max-pooling + att_logit = torch.max(score, dim=2)[0] + else: + score_soft = self.softmax(score * self.T) + att_logit = torch.sum(score * score_soft, dim=2) + + return base_logit + self.lam * att_logit diff --git a/mmpretrain/models/heads/multi_label_linear_head.py b/mmpretrain/models/heads/multi_label_linear_head.py new file mode 100644 index 0000000..81217ec --- /dev/null +++ b/mmpretrain/models/heads/multi_label_linear_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .multi_label_cls_head import MultiLabelClsHead + + +@MODELS.register_module() +class MultiLabelLinearClsHead(MultiLabelClsHead): + """Linear classification head for multilabel task. + + Args: + loss (dict): Config of classification loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True). + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + init_cfg (dict, optional): The extra init config of layers. + Defaults to use dict(type='Normal', layer='Linear', std=0.01). + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True), + thr: Optional[float] = None, + topk: Optional[int] = None, + init_cfg: Optional[dict] = dict( + type='Normal', layer='Linear', std=0.01)): + super(MultiLabelLinearClsHead, self).__init__( + loss=loss, thr=thr, topk=topk, init_cfg=init_cfg) + + assert num_classes > 0, f'num_classes ({num_classes}) must be a ' \ + 'positive integer.' + + self.in_channels = in_channels + self.num_classes = num_classes + + self.fc = nn.Linear(self.in_channels, self.num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``MultiLabelLinearClsHead``, we just + obtain the feature of the last stage. + """ + # The obtain the MultiLabelLinearClsHead doesn't have other module, + # just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/multi_task_head.py b/mmpretrain/models/heads/multi_task_head.py new file mode 100644 index 0000000..3515a2b --- /dev/null +++ b/mmpretrain/models/heads/multi_task_head.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Sequence, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule, ModuleDict + +from mmpretrain.registry import MODELS +from mmpretrain.structures import MultiTaskDataSample + + +def loss_convertor(loss_func, task_name): + + def wrapped(inputs, data_samples, **kwargs): + mask = torch.empty(len(data_samples), dtype=torch.bool) + task_data_samples = [] + for i, data_sample in enumerate(data_samples): + assert isinstance(data_sample, MultiTaskDataSample) + sample_mask = task_name in data_sample + mask[i] = sample_mask + if sample_mask: + task_data_samples.append(data_sample.get(task_name)) + + if len(task_data_samples) == 0: + # This makes it possible to perform loss.backward when a + # task does not have gt_labels within a batch. + loss = (inputs[0] * 0).sum() + return {'loss': loss, 'mask_size': torch.tensor(0.)} + + # Mask the inputs of the task + def mask_inputs(inputs, mask): + if isinstance(inputs, Sequence): + return type(inputs)( + [mask_inputs(input, mask) for input in inputs]) + elif isinstance(inputs, torch.Tensor): + return inputs[mask] + + masked_inputs = mask_inputs(inputs, mask) + loss_output = loss_func(masked_inputs, task_data_samples, **kwargs) + loss_output['mask_size'] = mask.sum().to(torch.float) + return loss_output + + return wrapped + + +@MODELS.register_module() +class MultiTaskHead(BaseModule): + """Multi task head. + + Args: + task_heads (dict): Sub heads to use, the key will be use to rename the + loss components. + common_cfg (dict): The common settings for all heads. Defaults to an + empty dict. + init_cfg (dict, optional): The extra initialization settings. + Defaults to None. + """ + + def __init__(self, task_heads, init_cfg=None, **kwargs): + super(MultiTaskHead, self).__init__(init_cfg=init_cfg) + + assert isinstance(task_heads, dict), 'The `task_heads` argument' \ + "should be a dict, which's keys are task names and values are" \ + 'configs of head for the task.' + + self.task_heads = ModuleDict() + + for task_name, sub_head in task_heads.items(): + if not isinstance(sub_head, nn.Module): + sub_head = MODELS.build(sub_head, default_args=kwargs) + sub_head.loss = loss_convertor(sub_head.loss, task_name) + self.task_heads[task_name] = sub_head + + def forward(self, feats): + """The forward process.""" + return { + task_name: head(feats) + for task_name, head in self.task_heads.items() + } + + def loss(self, feats: Tuple[torch.Tensor], + data_samples: List[MultiTaskDataSample], **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + data_samples (List[MultiTaskDataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components, each task loss + key will be prefixed by the task_name like "task1_loss" + """ + losses = dict() + for task_name, head in self.task_heads.items(): + head_loss = head.loss(feats, data_samples, **kwargs) + for k, v in head_loss.items(): + losses[f'{task_name}_{k}'] = v + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: List[MultiTaskDataSample] = None + ) -> List[MultiTaskDataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + data_samples (List[MultiTaskDataSample], optional): The annotation + data of every samples. 
If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[MultiTaskDataSample]: A list of data samples which contains + the predicted results. + """ + predictions_dict = dict() + + for task_name, head in self.task_heads.items(): + task_samples = None + if data_samples is not None: + task_samples = [ + data_sample.get(task_name, None) if data_sample else None + for data_sample in data_samples + ] + + task_samples = head.predict(feats, task_samples) + batch_size = len(task_samples) + predictions_dict[task_name] = task_samples + + if data_samples is None: + data_samples = [MultiTaskDataSample() for _ in range(batch_size)] + else: + data_samples = [ + MultiTaskDataSample() if data_sample is None else data_sample + for data_sample in data_samples + ] + + for task_name, task_samples in predictions_dict.items(): + for data_sample, task_sample in zip(data_samples, task_samples): + task_sample.set_field( + task_name in data_sample.tasks, + 'eval_mask', + field_type='metainfo') + + if task_name in data_sample.tasks: + data_sample.get(task_name).update(task_sample) + else: + data_sample.set_field(task_sample, task_name) + + return data_samples diff --git a/mmpretrain/models/heads/seq_gen_head.py b/mmpretrain/models/heads/seq_gen_head.py new file mode 100644 index 0000000..b2e9b10 --- /dev/null +++ b/mmpretrain/models/heads/seq_gen_head.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SeqGenerationHead(BaseModule): + """Generation head for multi-modal pre-trained task, adopted by BLIP. + Normally used for generation task. + + Args: + decoder (dict): Decoder for blip generation head. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__( + self, + decoder: dict, + ignore_index=-100, + loss: dict = dict(type='LabelSmoothLoss', label_smooth_val=0.1), + init_cfg: Optional[dict] = None, + ) -> None: + super(SeqGenerationHead, self).__init__(init_cfg=init_cfg) + self.decoder = MODELS.build(decoder) + self.loss_fn = MODELS.build(loss) + self.ignore_index = ignore_index + + def forward(self, input_ids: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, labels: torch.Tensor): + """Forward to get decoder output. + + Args: + input_ids (torch.Tensor): The tokenized input text tensor. + encoder_hidden_states (torch.Tensor): Hidden states from image + embeddings. + encoder_attention_mask (torch.Tensor): Image embeddings hidden + states attention mask. + labels (torch.Tensor): Decoder target for calculate loss. + + Returns: + dict[str, Tensor]: a dictionary of decoder outputs. + """ + + decoder_out = self.decoder( + input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + labels=labels, + return_dict=True, + ) + return decoder_out + + def loss(self, input_ids, encoder_hidden_states, encoder_attention_mask, + labels): + """Calculate losses from the extracted features. + + Args: + input_ids (torch.Tensor): The tokenized input text tensor. + encoder_hidden_states (torch.Tensor): Hidden states from image + embeddings. + encoder_attention_mask (torch.Tensor): Image embeddings hidden + states attention mask. + labels (torch.Tensor): Decoder target for calculate loss. + + Returns: + dict[str, Tensor]: a dictionary of loss components. 
+ """ + + decoder_out = self( + input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + labels=labels, + ) + prediction_scores = decoder_out['logits'] + # we are doing next-token prediction; + # shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + + vocab_size = prediction_scores.shape[-1] + + # mask ignored index + if (labels == self.ignore_index).any(): + labels = labels.view(-1).clone() + ignore_mask = (labels == self.ignore_index) + labels.masked_fill_(ignore_mask, 0) + weight = torch.logical_not(ignore_mask) + avg_factor = max(weight.sum(), 1) + else: + weight = None + avg_factor = labels.size(0) + + lm_loss = self.loss_fn( + shifted_prediction_scores.view(-1, vocab_size), + labels, + weight=weight, + avg_factor=avg_factor, + ) + losses = { + 'seq_gen_lm_loss': lm_loss, + } + + return losses + + def predict(self, + input_ids, + encoder_hidden_states, + sep_token_id, + pad_token_id, + use_nucleus_sampling=False, + num_beams=3, + max_length=20, + min_length=2, + top_p=0.9, + repetition_penalty=1.0, + **kwargs): + """Decoder prediction method. + + Args: + input_ids (torch.Tensor): The tokenized input text tensor. + encoder_hidden_states (torch.Tensor): Hidden states from image + embeddings. + sep_token_id (int): Tokenid of separation token. + pad_token_id (int): Tokenid of pad token. + use_nucleus_sampling (bool): Whether to use nucleus sampling in + prediction. Defaults to False. + num_beams (int): Number of beams used in predition. + Defaults to 3. + max_length (int): Max length of generated text in predition. + Defaults to 20. + min_length (int): Min length of generated text in predition. + Defaults to 20. + top_p (float): + If < 1.0, only keep the top tokens with cumulative probability + >= top_p (nucleus filtering). Defaults to 0.9. + repetition_penalty (float): The parameter for repetition penalty. + Defaults to 1.0. + **kwarg: Other arguments that might used in generation. + + Returns: + dict[str, Tensor]: a dictionary of generation outputs. + """ + device = encoder_hidden_states.device + + # TODO: In old version of transformers + # Additional repeat interleave of hidden states should be add here. + image_atts = torch.ones( + encoder_hidden_states.size()[:-1], dtype=torch.long).to(device) + + model_kwargs = { + 'encoder_hidden_states': encoder_hidden_states, + 'encoder_attention_mask': image_atts, + } + model_kwargs.update(kwargs) + + if use_nucleus_sampling: + # nucleus sampling + outputs = self.decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + # beam search + outputs = self.decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + return outputs diff --git a/mmpretrain/models/heads/simmim_head.py b/mmpretrain/models/heads/simmim_head.py new file mode 100644 index 0000000..b7af984 --- /dev/null +++ b/mmpretrain/models/heads/simmim_head.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SimMIMHead(BaseModule): + """Head for SimMIM Pre-training. + + Args: + patch_size (int): Patch size of each token. + loss (dict): The config for loss. + """ + + def __init__(self, patch_size: int, loss: dict) -> None: + super().__init__() + self.patch_size = patch_size + self.loss_module = MODELS.build(loss) + + def loss(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + This method will expand mask to the size of the original image. + + Args: + pred (torch.Tensor): The reconstructed image (B, C, H, W). + target (torch.Tensor): The target image (B, C, H, W). + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave( + self.patch_size, 2).unsqueeze(1).contiguous() + loss = self.loss_module(pred, target, mask) + + return loss diff --git a/mmpretrain/models/heads/spark_head.py b/mmpretrain/models/heads/spark_head.py new file mode 100644 index 0000000..a274876 --- /dev/null +++ b/mmpretrain/models/heads/spark_head.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SparKPretrainHead(BaseModule): + """Pre-training head for SparK. + + Args: + loss (dict): Config of loss. + norm_pix (bool): Whether or not normalize target. Defaults to True. + patch_size (int): Patch size, equal to downsample ratio of backbone. + Defaults to 32. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = True, + patch_size: int = 32) -> None: + super().__init__() + self.norm_pix = norm_pix + self.patch_size = patch_size + self.loss = MODELS.build(loss) + + def patchify(self, imgs): + """Split images into non-overlapped patches. + + Args: + imgs (torch.Tensor): A batch of images, of shape B x C x H x W. + Returns: + torch.Tensor: Patchified images. The shape is B x L x D. + """ + p = self.patch_size + assert len(imgs.shape + ) == 4 and imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0 + + B, C, ori_h, ori_w = imgs.shape + h = ori_h // p + w = ori_w // p + x = imgs.reshape(shape=(B, C, h, p, w, p)) + x = torch.einsum('bchpwq->bhwpqc', x) + + # (B, f*f, downsample_raito*downsample_raito*3) + x = x.reshape(shape=(B, h * w, p**2 * C)) + return x + + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + """Construct the reconstruction target. + + In addition to splitting images into tokens, this module will also + normalize the image according to ``norm_pix``. + Args: + target (torch.Tensor): Image with the shape of B x 3 x H x W + Returns: + torch.Tensor: Tokenized images with the shape of B x L x C + """ + target = self.patchify(target) + if self.norm_pix: + # normalize the target image + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + return target + + def forward(self, pred: torch.Tensor, target: torch.Tensor, + active_mask: torch.Tensor) -> torch.Tensor: + """Forward function of MAE head. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + active_mask (torch.Tensor): The mask of the target image. + Returns: + torch.Tensor: The reconstruction loss. 
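+ Examples:
+     A minimal shape sketch of the patchification applied to both the
+     prediction and the target (image size is arbitrary, patch size 32)::
+ 
+         >>> import torch
+         >>> imgs = torch.rand(2, 3, 224, 224)
+         >>> x = imgs.reshape(2, 3, 7, 32, 7, 32)
+         >>> x = torch.einsum('bchpwq->bhwpqc', x).reshape(2, 49, 32 * 32 * 3)
+         >>> x.shape
+         torch.Size([2, 49, 3072])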
+ """ + # (B, C, H, W) -> (B, L, C) and perform normalization + target = self.construct_target(target) + + # (B, C, H, W) -> (B, L, C) + pred = self.patchify(pred) + + # (B, 1, f, f) -> (B, L) + non_active_mask = active_mask.logical_not().int().view( + active_mask.shape[0], -1) + + # MSE loss on masked patches + loss = self.loss(pred, target, non_active_mask) + return loss diff --git a/mmpretrain/models/heads/stacked_head.py b/mmpretrain/models/heads/stacked_head.py new file mode 100644 index 0000000..6cd819d --- /dev/null +++ b/mmpretrain/models/heads/stacked_head.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +class LinearBlock(BaseModule): + """Linear block for StackedLinearClsHead.""" + + def __init__(self, + in_channels, + out_channels, + dropout_rate=0., + norm_cfg=None, + act_cfg=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.fc = nn.Linear(in_channels, out_channels) + + self.norm = None + self.act = None + self.dropout = None + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + if act_cfg is not None: + self.act = build_activation_layer(act_cfg) + if dropout_rate > 0: + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, x): + """The forward process.""" + x = self.fc(x) + if self.norm is not None: + x = self.norm(x) + if self.act is not None: + x = self.act(x) + if self.dropout is not None: + x = self.dropout(x) + return x + + +@MODELS.register_module() +class StackedLinearClsHead(ClsHead): + """Classifier head with several hidden fc layer and a output fc layer. + + Args: + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + mid_channels (Sequence[int]): Number of channels in the hidden fc + layers. + dropout_rate (float): Dropout rate after each hidden fc layer, + except the last layer. Defaults to 0. + norm_cfg (dict, optional): Config dict of normalization layer after + each hidden fc layer, except the last layer. Defaults to None. + act_cfg (dict, optional): Config dict of activation function after each + hidden layer, except the last layer. Defaults to use "ReLU". 
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + mid_channels: Sequence[int], + dropout_rate: float = 0., + norm_cfg: Optional[Dict] = None, + act_cfg: Optional[Dict] = dict(type='ReLU'), + **kwargs): + super(StackedLinearClsHead, self).__init__(**kwargs) + self.num_classes = num_classes + self.in_channels = in_channels + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + assert isinstance(mid_channels, Sequence), \ + f'`mid_channels` of StackedLinearClsHead should be a sequence, ' \ + f'instead of {type(mid_channels)}' + self.mid_channels = mid_channels + + self.dropout_rate = dropout_rate + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self._init_layers() + + def _init_layers(self): + """"Init layers.""" + self.layers = ModuleList() + in_channels = self.in_channels + for hidden_channels in self.mid_channels: + self.layers.append( + LinearBlock( + in_channels, + hidden_channels, + dropout_rate=self.dropout_rate, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + in_channels = hidden_channels + + self.layers.append( + LinearBlock( + self.mid_channels[-1], + self.num_classes, + dropout_rate=0., + norm_cfg=None, + act_cfg=None)) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. + """ + x = feats[-1] + for layer in self.layers[:-1]: + x = layer(x) + return x + + @property + def fc(self): + """Full connected layer.""" + return self.layers[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/swav_head.py b/mmpretrain/models/heads/swav_head.py new file mode 100644 index 0000000..8f3a302 --- /dev/null +++ b/mmpretrain/models/heads/swav_head.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SwAVHead(BaseModule): + """Head for SwAV Pre-training. + + Args: + loss (dict): Config dict for module of loss functions. + """ + + def __init__(self, loss: dict) -> None: + super().__init__() + self.loss_module = MODELS.build(loss) + + def loss(self, pred: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + pred (torch.Tensor): NxC input features. + + Returns: + torch.Tensor: The SwAV loss. + """ + loss = self.loss_module(pred) + + return loss diff --git a/mmpretrain/models/heads/vig_head.py b/mmpretrain/models/heads/vig_head.py new file mode 100644 index 0000000..ecb984d --- /dev/null +++ b/mmpretrain/models/heads/vig_head.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class VigClsHead(ClsHead): + """The classification head for Vision GNN. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int): The number of middle channels. Defaults to 1024. + act_cfg (dict): The config of activation function. + Defaults to ``dict(type='GELU')``. 
+ dropout (float): The dropout rate. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + num_classes: int, + in_channels: int, + hidden_dim: int = 1024, + act_cfg: dict = dict(type='GELU'), + dropout: float = 0., + **kwargs): + super().__init__(**kwargs) + + self.fc1 = nn.Linear(in_channels, hidden_dim) + self.bn = nn.BatchNorm1d(hidden_dim) + self.act = build_activation_layer(act_cfg) + self.drop = nn.Dropout(dropout) + self.fc2 = nn.Linear(hidden_dim, num_classes) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a stage_blocks stage. In ``VigClsHead``, we just obtain the + feature of the last stage. + """ + feats = feats[-1] + feats = self.fc1(feats) + feats = self.bn(feats) + feats = self.act(feats) + feats = self.drop(feats) + + return feats + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc2(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/vision_transformer_head.py b/mmpretrain/models/heads/vision_transformer_head.py new file mode 100644 index 0000000..83e8fca --- /dev/null +++ b/mmpretrain/models/heads/vision_transformer_head.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from collections import OrderedDict +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer +from mmengine.model import Sequential +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from .cls_head import ClsHead + + +@MODELS.register_module() +class VisionTransformerClsHead(ClsHead): + """Vision Transformer classifier head. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + hidden_dim (int, optional): Number of the dimensions for hidden layer. + Defaults to None, which means no extra hidden layer. + act_cfg (dict): The activation config. Only available during + pre-training. Defaults to ``dict(type='Tanh')``. + init_cfg (dict): The extra initialization configs. Defaults to + ``dict(type='Constant', layer='Linear', val=0)``. 
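+ Examples:
+     A minimal shape sketch (the batch size and embedding width are
+     arbitrary, and the default ``ClsHead`` loss config is assumed to be
+     available)::
+ 
+         >>> import torch
+         >>> head = VisionTransformerClsHead(num_classes=1000, in_channels=768)
+         >>> head((torch.rand(2, 768), )).shape
+         torch.Size([2, 1000])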
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + hidden_dim: Optional[int] = None, + act_cfg: dict = dict(type='Tanh'), + init_cfg: dict = dict(type='Constant', layer='Linear', val=0), + **kwargs): + super(VisionTransformerClsHead, self).__init__( + init_cfg=init_cfg, **kwargs) + self.in_channels = in_channels + self.num_classes = num_classes + self.hidden_dim = hidden_dim + self.act_cfg = act_cfg + + if self.num_classes <= 0: + raise ValueError( + f'num_classes={num_classes} must be a positive integer') + + self._init_layers() + + def _init_layers(self): + """"Init hidden layer if exists.""" + if self.hidden_dim is None: + layers = [('head', nn.Linear(self.in_channels, self.num_classes))] + else: + layers = [ + ('pre_logits', nn.Linear(self.in_channels, self.hidden_dim)), + ('act', build_activation_layer(self.act_cfg)), + ('head', nn.Linear(self.hidden_dim, self.num_classes)), + ] + self.layers = Sequential(OrderedDict(layers)) + + def init_weights(self): + """"Init weights of hidden layer if exists.""" + super(VisionTransformerClsHead, self).init_weights() + # Modified from ClassyVision + if hasattr(self.layers, 'pre_logits'): + # Lecun norm + trunc_normal_( + self.layers.pre_logits.weight, + std=math.sqrt(1 / self.layers.pre_logits.in_features)) + nn.init.zeros_(self.layers.pre_logits.bias) + + def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of list of tensor, and each tensor is + the feature of a backbone stage. In ``VisionTransformerClsHead``, we + obtain the feature of the last stage and forward in hidden layer if + exists. + """ + feat = feats[-1] # Obtain feature of the last scale. + # For backward-compatibility with the previous ViT output + cls_token = feat[-1] if isinstance(feat, list) else feat + if self.hidden_dim is None: + return cls_token + else: + x = self.layers.pre_logits(cls_token) + return self.layers.act(x) + + def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.layers.head(pre_logits) + return cls_score diff --git a/mmpretrain/models/heads/vqa_head.py b/mmpretrain/models/heads/vqa_head.py new file mode 100644 index 0000000..c7b5fe5 --- /dev/null +++ b/mmpretrain/models/heads/vqa_head.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import mmengine +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class VQAGenerationHead(BaseModule): + """Generation head for multi-modal pre-trained task, adapted by BLIP. + Normally used for qa generation task (open-set) + + Args: + decoder (dict): Decoder for decoding answers. + inference_method (str): Inference method. One of 'rank', 'generate'. + - If 'rank', the model will return answers with the highest + probability from the answer list. + - If 'generate', the model will generate answers. + - Only for test, not for train / val. + num_beams (int): Number of beams for beam search. 1 means no beam + search. Only support when inference_method=='generate'. + Defaults to 3. + num_ans_candidates (int): Number of answer candidates, used to filter + out answers with low probability. Only support when + inference_method=='rank'. Defaults to 128. 
+ loss (dict or nn.Module): Config of loss or module of loss. Defaults to + ``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + answer_list_path (str, optional): Path to `answer_list.json` + (json file of a answer list). Required when + inference_method=='rank'. + + + TODO: `mmcls.LabelSmoothLoss` has not support `ignore_index` param. + Now using `nn.CrossEntropyLoss`, without label_smoothing, in order to + maintain compatibility with torch < 1.10.0 + """ + + def __init__( + self, + decoder: dict, + inference_method: str = 'generate', + num_beams: int = 3, + num_ans_candidates: int = 128, + loss: Union[dict, nn.Module] = nn.CrossEntropyLoss( + reduction='none', ignore_index=-100), + init_cfg: Optional[dict] = None, + answer_list_path: Optional[str] = None, + ) -> None: + + super(VQAGenerationHead, self).__init__(init_cfg=init_cfg) + self.decoder = MODELS.build(decoder) + + if inference_method == 'generate': + assert isinstance(num_beams, int), \ + 'for VQA `generate` mode, `num_beams` must be a int.' + self.num_beams = num_beams + self.num_ans_candidates = None + self.answer_list = None + + elif inference_method == 'rank': + assert isinstance(num_ans_candidates, int), \ + 'for VQA `rank` mode, `num_ans_candidates` must be a int.' + assert isinstance(answer_list_path, str), \ + 'for VQA `rank` mode, `answer_list_path` must be set as ' \ + 'the path to `answer_list.json`.' + self.num_beams = None + self.answer_list = mmengine.load(answer_list_path) + if isinstance(self.answer_list, dict): + self.answer_list = list(self.answer_list.keys()) + assert isinstance(self.answer_list, list) and all( + isinstance(item, str) for item in self.answer_list), \ + 'for VQA `rank` mode, `answer_list.json` must be a list of str' + self.num_ans_candidates = min(num_ans_candidates, + len(self.answer_list)) + + else: + raise AssertionError( + 'for VQA, `inference_method` must be "generate" or "rank", ' + 'got {}.'.format(inference_method)) + + self.inference_method = inference_method + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + + def forward(self, feats: dict): + prediction_logits = self.decoder( + feats['answer_input_ids'], + attention_mask=feats['answer_attention_mask'], + encoder_hidden_states=feats['question_states'], + encoder_attention_mask=feats['question_atts'], + labels=feats['answer_targets'], + return_dict=True, + return_logits=True, # directly return logits, not computing loss + reduction='none', + ) + return prediction_logits + + def loss(self, feats: dict, data_samples=None): + """Calculate losses from the extracted features. + + Args: + feats (dict): The features extracted from the backbone. + data_samples (List[BaseDataElement]): The annotation data of + every samples. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + shifted_prediction_scores = self(feats) + labels = feats['answer_targets'] + lm_loss = None + + # we are doing next-token prediction; + # shift prediction scores and input ids by one + labels = labels[:, 1:].contiguous() + lm_loss = self.loss_module( + shifted_prediction_scores.view(-1, + self.decoder.med_config.vocab_size), + labels.view(-1)) + lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1) + # compute weighted loss + losses = dict() + loss = feats['answer_weight'] * lm_loss + loss = loss.sum() / feats['batch_size'] + losses['vqa_loss'] = loss + + return losses + + def predict_rank(self, feats: dict, data_samples=None): + """Predict rank in a close-set answer list.""" + question_states = feats['multimodal_embeds'] + question_atts = feats['question_atts'] + answer_candidates = feats['answer_candidates'] + assert answer_candidates is not None + + answer_ids = answer_candidates.input_ids + answer_atts = answer_candidates.attention_mask + num_ques = question_states.size(0) + start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token + + start_output = self.decoder( + start_ids, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + return_dict=True, + reduction='none', + ) + logits = start_output.logits[:, 0, :] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:, 1] + prob_first_token = F.softmax( + logits, dim=1).index_select( + dim=1, index=answer_first_token) + topk_probs, topk_ids = prob_first_token.topk( + self.num_ans_candidates, dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids, dim=0) + input_atts = torch.cat(input_atts, dim=0) + + targets_ids = input_ids.masked_fill(input_ids == feats['pad_token_id'], + -100) + + def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([ + init_dim * np.arange(n_tile) + i for i in range(init_dim) + ])) + return torch.index_select(x, dim, order_index.to(x.device)) + + # repeat encoder's output for top-k answers + question_states = tile(question_states, 0, self.num_ans_candidates) + question_atts = tile(question_atts, 0, self.num_ans_candidates) + + output = self.decoder( + input_ids, + attention_mask=input_atts, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=targets_ids, + return_dict=True, + reduction='none', + ) + + log_probs_sum = -output.loss + log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates) + + max_topk_ids = log_probs_sum.argmax(dim=1) + max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids] + + answers = [self.answer_list[max_id] for max_id in max_ids] + + return answers + + def predict_generate(self, feats: dict, data_samples=None): + """Predict answers in a generation manner.""" + device = feats['multimodal_embeds'].device + question_states = feats['multimodal_embeds'] + question_atts = torch.ones( + question_states.size()[:-1], dtype=torch.long).to(device) + model_kwargs = { + 'encoder_hidden_states': question_states, + 'encoder_attention_mask': question_atts + } + + bos_ids = 
torch.full((feats['multimodal_embeds'].shape[0], 1), + fill_value=feats['bos_token_id'], + device=device) + + outputs = self.decoder.generate( + input_ids=bos_ids, + max_length=10, + min_length=1, + num_beams=self.num_beams, + eos_token_id=feats['sep_token_id'], + pad_token_id=feats['pad_token_id'], + **model_kwargs) + + return outputs + + def predict(self, feats: dict, data_samples=None): + """Predict results from the extracted features.""" + if self.inference_method == 'generate': + return self.predict_generate(feats, data_samples) + elif self.inference_method == 'rank': + return self.predict_rank(feats, data_samples) diff --git a/mmpretrain/models/losses/__init__.py b/mmpretrain/models/losses/__init__.py new file mode 100644 index 0000000..b1b2ed7 --- /dev/null +++ b/mmpretrain/models/losses/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .asymmetric_loss import AsymmetricLoss, asymmetric_loss +from .cae_loss import CAELoss +from .cosine_similarity_loss import CosineSimilarityLoss +from .cross_correlation_loss import CrossCorrelationLoss +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy) +from .focal_loss import FocalLoss, sigmoid_focal_loss +from .label_smooth_loss import LabelSmoothLoss +from .reconstruction_loss import PixelReconstructionLoss +from .seesaw_loss import SeesawLoss +from .swav_loss import SwAVLoss +from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss, + weighted_loss) + +__all__ = [ + 'asymmetric_loss', + 'AsymmetricLoss', + 'cross_entropy', + 'binary_cross_entropy', + 'CrossEntropyLoss', + 'reduce_loss', + 'weight_reduce_loss', + 'LabelSmoothLoss', + 'weighted_loss', + 'FocalLoss', + 'sigmoid_focal_loss', + 'convert_to_one_hot', + 'SeesawLoss', + 'CAELoss', + 'CosineSimilarityLoss', + 'CrossCorrelationLoss', + 'PixelReconstructionLoss', + 'SwAVLoss', +] diff --git a/mmpretrain/models/losses/asymmetric_loss.py b/mmpretrain/models/losses/asymmetric_loss.py new file mode 100644 index 0000000..dcc9707 --- /dev/null +++ b/mmpretrain/models/losses/asymmetric_loss.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .utils import convert_to_one_hot, weight_reduce_loss + + +def asymmetric_loss(pred, + target, + weight=None, + gamma_pos=1.0, + gamma_neg=4.0, + clip=0.05, + reduction='mean', + avg_factor=None, + use_sigmoid=True, + eps=1e-8): + r"""asymmetric loss. + + Please refer to the `paper `__ for + details. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction with + shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, ). Defaults to None. + gamma_pos (float): positive focusing parameter. Defaults to 0.0. + gamma_neg (float): Negative focusing parameter. We usually set + gamma_neg > gamma_pos. Defaults to 4.0. + clip (float, optional): Probability margin. Defaults to 0.05. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , loss + is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + use_sigmoid (bool): Whether the prediction uses sigmoid instead + of softmax. Defaults to True. + eps (float): The minimum value of the argument of logarithm. Defaults + to 1e-8. + + Returns: + torch.Tensor: Loss. 
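+ Example:
+     A minimal sketch with arbitrary multi-label logits and targets::
+ 
+         >>> import torch
+         >>> pred = torch.rand(4, 10)
+         >>> target = torch.randint(0, 2, (4, 10)).float()
+         >>> asymmetric_loss(pred, target).ndim  # reduced to a scalar
+         0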
+ """ + assert pred.shape == \ + target.shape, 'pred and target should be in the same shape.' + + if use_sigmoid: + pred_sigmoid = pred.sigmoid() + else: + pred_sigmoid = nn.functional.softmax(pred, dim=-1) + + target = target.type_as(pred) + + if clip and clip > 0: + pt = (1 - pred_sigmoid + + clip).clamp(max=1) * (1 - target) + pred_sigmoid * target + else: + pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target + asymmetric_weight = (1 - pt).pow(gamma_pos * target + gamma_neg * + (1 - target)) + loss = -torch.log(pt.clamp(min=eps)) * asymmetric_weight + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class AsymmetricLoss(nn.Module): + """asymmetric loss. + + Args: + gamma_pos (float): positive focusing parameter. + Defaults to 0.0. + gamma_neg (float): Negative focusing parameter. We + usually set gamma_neg > gamma_pos. Defaults to 4.0. + clip (float, optional): Probability margin. Defaults to 0.05. + reduction (str): The method used to reduce the loss into + a scalar. + loss_weight (float): Weight of loss. Defaults to 1.0. + use_sigmoid (bool): Whether the prediction uses sigmoid instead + of softmax. Defaults to True. + eps (float): The minimum value of the argument of logarithm. Defaults + to 1e-8. + """ + + def __init__(self, + gamma_pos=0.0, + gamma_neg=4.0, + clip=0.05, + reduction='mean', + loss_weight=1.0, + use_sigmoid=True, + eps=1e-8): + super(AsymmetricLoss, self).__init__() + self.gamma_pos = gamma_pos + self.gamma_neg = gamma_neg + self.clip = clip + self.reduction = reduction + self.loss_weight = loss_weight + self.use_sigmoid = use_sigmoid + self.eps = eps + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + r"""asymmetric loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction + with shape (N, \*), N or (N,1). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The method used to reduce the + loss into a scalar. Options are "none", "mean" and "sum". + Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if target.dim() == 1 or (target.dim() == 2 and target.shape[1] == 1): + target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1]) + loss_cls = self.loss_weight * asymmetric_loss( + pred, + target, + weight, + gamma_pos=self.gamma_pos, + gamma_neg=self.gamma_neg, + clip=self.clip, + reduction=reduction, + avg_factor=avg_factor, + use_sigmoid=self.use_sigmoid, + eps=self.eps) + return loss_cls diff --git a/mmpretrain/models/losses/cae_loss.py b/mmpretrain/models/losses/cae_loss.py new file mode 100644 index 0000000..1dc081b --- /dev/null +++ b/mmpretrain/models/losses/cae_loss.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CAELoss(BaseModule): + """Loss function for CAE. 
+ + Compute the align loss and the main loss. + + Args: + lambd (float): The weight for the align loss. + """ + + def __init__(self, lambd: float) -> None: + super().__init__() + self.lambd = lambd + self.loss_cross_entropy = nn.CrossEntropyLoss() + self.loss_mse = nn.MSELoss() + + def forward( + self, logits: torch.Tensor, target: torch.Tensor, + latent_pred: torch.Tensor, + latent_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function of CAE Loss. + + Args: + logits (torch.Tensor): The outputs from the decoder. + target (torch.Tensor): The targets generated by dalle. + latent_pred (torch.Tensor): The latent prediction from the + regressor. + latent_target (torch.Tensor): The latent target from the teacher + network. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The main loss and align loss. + """ + loss_main = self.loss_cross_entropy(logits, target) + loss_align = self.loss_mse(latent_pred, + latent_target.detach()) * self.lambd + + return loss_main, loss_align diff --git a/mmpretrain/models/losses/cosine_similarity_loss.py b/mmpretrain/models/losses/cosine_similarity_loss.py new file mode 100644 index 0000000..f0a5931 --- /dev/null +++ b/mmpretrain/models/losses/cosine_similarity_loss.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Optional + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CosineSimilarityLoss(BaseModule): + """Cosine similarity loss function. + + Compute the similarity between two features and optimize that similarity as + loss. + + Args: + shift_factor (float): The shift factor of cosine similarity. + Default: 0.0. + scale_factor (float): The scale factor of cosine similarity. + Default: 1.0. + """ + + def __init__(self, + shift_factor: float = 0.0, + scale_factor: float = 1.0) -> None: + super().__init__() + self.shift_factor = shift_factor + self.scale_factor = scale_factor + + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward function of cosine similarity loss. + + Args: + pred (torch.Tensor): The predicted features. + target (torch.Tensor): The target features. + + Returns: + torch.Tensor: The cosine similarity loss. + """ + pred_norm = nn.functional.normalize(pred, dim=-1) + target_norm = nn.functional.normalize(target, dim=-1) + loss = self.shift_factor - self.scale_factor * ( + pred_norm * target_norm).sum(dim=-1) + + if mask is None: + loss = loss.mean() + else: + loss = (loss * mask).sum() / mask.sum() + return loss diff --git a/mmpretrain/models/losses/cross_correlation_loss.py b/mmpretrain/models/losses/cross_correlation_loss.py new file mode 100644 index 0000000..d26ce3d --- /dev/null +++ b/mmpretrain/models/losses/cross_correlation_loss.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CrossCorrelationLoss(BaseModule): + """Cross correlation loss function. + + Compute the on-diagnal and off-diagnal loss. + + Args: + lambd (float): The weight for the off-diag loss. + """ + + def __init__(self, lambd: float = 0.0051) -> None: + super().__init__() + self.lambd = lambd + + def forward(self, cross_correlation_matrix: torch.Tensor) -> torch.Tensor: + """Forward function of cross correlation loss. 
+ + Args: + cross_correlation_matrix (torch.Tensor): The cross correlation + matrix. + + Returns: + torch.Tensor: cross correlation loss. + """ + # loss + on_diag = torch.diagonal(cross_correlation_matrix).add_(-1).pow_( + 2).sum() + off_diag = self.off_diagonal(cross_correlation_matrix).pow_(2).sum() + loss = on_diag + self.lambd * off_diag + return loss + + def off_diagonal(self, x: torch.Tensor) -> torch.Tensor: + """Rreturn a flattened view of the off-diagonal elements of a square + matrix.""" + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() diff --git a/mmpretrain/models/losses/cross_entropy_loss.py b/mmpretrain/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000..5d418be --- /dev/null +++ b/mmpretrain/models/losses/cross_entropy_loss.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None): + """Calculate the CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + label (torch.Tensor): The gt label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # element-wise losses + loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none') + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def soft_cross_entropy(pred, + label, + weight=None, + reduction='mean', + class_weight=None, + avg_factor=None): + """Calculate the Soft CrossEntropy loss. The label can be float. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + label (torch.Tensor): The gt label of the prediction with shape (N, C). + When using "mixup", the label can be float. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # element-wise losses + loss = -label * F.log_softmax(pred, dim=-1) + if class_weight is not None: + loss *= class_weight + loss = loss.sum(dim=-1) + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + pos_weight=None): + r"""Calculate the binary CrossEntropy loss with logits. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + label (torch.Tensor): The gt label with shape (N, \*). 
+ weight (torch.Tensor, optional): Element-wise weight of loss with shape + (N, ). Defaults to None. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , loss + is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (torch.Tensor, optional): The weight for each class with + shape (C), C is the number of classes. Default None. + pos_weight (torch.Tensor, optional): The positive weight for each + class with shape (C), C is the number of classes. Default None. + + Returns: + torch.Tensor: The calculated loss + """ + # Ensure that the size of class_weight is consistent with pred and label to + # avoid automatic boracast, + assert pred.dim() == label.dim() + + if class_weight is not None: + N = pred.size()[0] + class_weight = class_weight.repeat(N, 1) + loss = F.binary_cross_entropy_with_logits( + pred, + label.float(), # only accepts float type tensor + weight=class_weight, + pos_weight=pos_weight, + reduction='none') + + # apply weights and do the reduction + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + return loss + + +@MODELS.register_module() +class CrossEntropyLoss(nn.Module): + """Cross entropy loss. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_soft (bool): Whether to use the soft version of CrossEntropyLoss. + Defaults to False. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". Defaults to 'mean'. + loss_weight (float): Weight of the loss. Defaults to 1.0. + class_weight (List[float], optional): The weight for each class with + shape (C), C is the number of classes. Default None. + pos_weight (List[float], optional): The positive weight for each + class with shape (C), C is the number of classes. Only enabled in + BCE loss when ``use_sigmoid`` is True. Default None. 
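+ Examples:
+     A minimal sketch of the softmax and the sigmoid (binary) variants with
+     arbitrary logits and labels::
+ 
+         >>> import torch
+         >>> loss = CrossEntropyLoss()
+         >>> loss(torch.rand(2, 5), torch.tensor([1, 3])).ndim
+         0
+         >>> bce = CrossEntropyLoss(use_sigmoid=True)
+         >>> bce(torch.rand(2, 5), torch.randint(0, 2, (2, 5))).ndim
+         0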
+ """ + + def __init__(self, + use_sigmoid=False, + use_soft=False, + reduction='mean', + loss_weight=1.0, + class_weight=None, + pos_weight=None): + super(CrossEntropyLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.use_soft = use_soft + assert not ( + self.use_soft and self.use_sigmoid + ), 'use_sigmoid and use_soft could not be set simultaneously' + + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + self.pos_weight = pos_weight + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_soft: + self.cls_criterion = soft_cross_entropy + else: + self.cls_criterion = cross_entropy + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # only BCE loss has pos_weight + if self.pos_weight is not None and self.use_sigmoid: + pos_weight = cls_score.new_tensor(self.pos_weight) + kwargs.update({'pos_weight': pos_weight}) + else: + pos_weight = None + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/mmpretrain/models/losses/focal_loss.py b/mmpretrain/models/losses/focal_loss.py new file mode 100644 index 0000000..9d2cf50 --- /dev/null +++ b/mmpretrain/models/losses/focal_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import convert_to_one_hot, weight_reduce_loss + + +def sigmoid_focal_loss(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + r"""Sigmoid focal loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction with + shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, ). Defaults to None. + gamma (float): The gamma for calculating the modulating factor. + Defaults to 2.0. + alpha (float): A balanced form for Focal Loss. Defaults to 0.25. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". If reduction is 'none' , + loss is same shape as pred and label. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + assert pred.shape == \ + target.shape, 'pred and target should be in the same shape.' + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * pt.pow(gamma) + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + if weight is not None: + assert weight.dim() == 1 + weight = weight.float() + if pred.dim() > 1: + weight = weight.reshape(-1, 1) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class FocalLoss(nn.Module): + """Focal loss. + + Args: + gamma (float): Focusing parameter in focal loss. + Defaults to 2.0. 
+ alpha (float): The parameter in balanced form of focal + loss. Defaults to 0.25. + reduction (str): The method used to reduce the loss into + a scalar. Options are "none" and "mean". Defaults to 'mean'. + loss_weight (float): Weight of loss. Defaults to 1.0. + """ + + def __init__(self, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=1.0): + + super(FocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + r"""Sigmoid focal loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The ground truth label of the prediction + with shape (N, \*), N or (N,1). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The method used to reduce the + loss into a scalar. Options are "none", "mean" and "sum". + Defaults to None. + + Returns: + torch.Tensor: Loss. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if target.dim() == 1 or (target.dim() == 2 and target.shape[1] == 1): + target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1]) + loss_cls = self.loss_weight * sigmoid_focal_loss( + pred, + target, + weight, + gamma=self.gamma, + alpha=self.alpha, + reduction=reduction, + avg_factor=avg_factor) + return loss_cls diff --git a/mmpretrain/models/losses/label_smooth_loss.py b/mmpretrain/models/losses/label_smooth_loss.py new file mode 100644 index 0000000..f117df3 --- /dev/null +++ b/mmpretrain/models/losses/label_smooth_loss.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .cross_entropy_loss import CrossEntropyLoss +from .utils import convert_to_one_hot + + +@MODELS.register_module() +class LabelSmoothLoss(nn.Module): + r"""Initializer for the label smoothed cross entropy loss. + + Refers to `Rethinking the Inception Architecture for Computer Vision + `_ + + This decreases gap between output scores and encourages generalization. + Labels provided to forward can be one-hot like vectors (NxC) or class + indices (Nx1). + And this accepts linear combination of one-hot like labels from mixup or + cutmix except multi-label task. + + Args: + label_smooth_val (float): The degree of label smoothing. + num_classes (int, optional): Number of classes. Defaults to None. + mode (str): Refers to notes, Options are 'original', 'classy_vision', + 'multi_label'. Defaults to 'original'. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid of + softmax. Defaults to None, which means to use sigmoid in + "multi_label" mode and not use in other modes. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". Defaults to 'mean'. + loss_weight (float): Weight of the loss. Defaults to 1.0. + + Notes: + - if the mode is **"original"**, this will use the same label smooth + method as the original paper as: + + .. math:: + (1-\epsilon)\delta_{k, y} + \frac{\epsilon}{K} + + where :math:`\epsilon` is the ``label_smooth_val``, :math:`K` is the + ``num_classes`` and :math:`\delta_{k, y}` is Dirac delta, which + equals 1 for :math:`k=y` and 0 otherwise. 
+ + - if the mode is **"classy_vision"**, this will use the same label + smooth method as the facebookresearch/ClassyVision repo as: + + .. math:: + \frac{\delta_{k, y} + \epsilon/K}{1+\epsilon} + + - if the mode is **"multi_label"**, this will accept labels from + multi-label task and smoothing them as: + + .. math:: + (1-2\epsilon)\delta_{k, y} + \epsilon + """ + + def __init__(self, + label_smooth_val, + num_classes=None, + use_sigmoid=None, + mode='original', + reduction='mean', + loss_weight=1.0, + class_weight=None, + pos_weight=None): + super().__init__() + self.num_classes = num_classes + self.loss_weight = loss_weight + + assert (isinstance(label_smooth_val, float) + and 0 <= label_smooth_val < 1), \ + f'LabelSmoothLoss accepts a float label_smooth_val ' \ + f'over [0, 1), but gets {label_smooth_val}' + self.label_smooth_val = label_smooth_val + + accept_reduction = {'none', 'mean', 'sum'} + assert reduction in accept_reduction, \ + f'LabelSmoothLoss supports reduction {accept_reduction}, ' \ + f'but gets {mode}.' + self.reduction = reduction + + accept_mode = {'original', 'classy_vision', 'multi_label'} + assert mode in accept_mode, \ + f'LabelSmoothLoss supports mode {accept_mode}, but gets {mode}.' + self.mode = mode + + self._eps = label_smooth_val + if mode == 'classy_vision': + self._eps = label_smooth_val / (1 + label_smooth_val) + + if mode == 'multi_label': + if not use_sigmoid: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().warning( + 'For multi-label tasks, please set `use_sigmoid=True` ' + 'to use binary cross entropy.') + self.smooth_label = self.multilabel_smooth_label + use_sigmoid = True if use_sigmoid is None else use_sigmoid + else: + self.smooth_label = self.original_smooth_label + use_sigmoid = False if use_sigmoid is None else use_sigmoid + + self.ce = CrossEntropyLoss( + use_sigmoid=use_sigmoid, + use_soft=not use_sigmoid, + reduction=reduction, + class_weight=class_weight, + pos_weight=pos_weight) + + def generate_one_hot_like_label(self, label): + """This function takes one-hot or index label vectors and computes one- + hot like label vectors (float)""" + # check if targets are inputted as class integers + if label.dim() == 1 or (label.dim() == 2 and label.shape[1] == 1): + label = convert_to_one_hot(label.view(-1, 1), self.num_classes) + return label.float() + + def original_smooth_label(self, one_hot_like_label): + assert self.num_classes > 0 + smooth_label = one_hot_like_label * (1 - self._eps) + smooth_label += self._eps / self.num_classes + return smooth_label + + def multilabel_smooth_label(self, one_hot_like_label): + assert self.num_classes > 0 + smooth_label = torch.full_like(one_hot_like_label, self._eps) + smooth_label.masked_fill_(one_hot_like_label > 0, 1 - self._eps) + return smooth_label + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + r"""Label smooth loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + label (torch.Tensor): The ground truth label of the prediction + with shape (N, \*). + weight (torch.Tensor, optional): Sample-wise loss weight with shape + (N, \*). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The method used to reduce the + loss into a scalar. Options are "none", "mean" and "sum". + Defaults to None. + + Returns: + torch.Tensor: Loss. 
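+
+        Example:
+            A minimal sketch of calling the loss; the class count and the
+            input values below are illustrative only and are not taken from
+            the original patch::
+
+                >>> import torch
+                >>> loss_module = LabelSmoothLoss(label_smooth_val=0.1,
+                ...                               num_classes=5)
+                >>> cls_score = torch.randn(2, 5)
+                >>> label = torch.tensor([0, 3])      # class indices
+                >>> loss_module(cls_score, label).shape
+                torch.Size([])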
+ """ + if self.num_classes is not None: + assert self.num_classes == cls_score.shape[1], \ + f'num_classes should equal to cls_score.shape[1], ' \ + f'but got num_classes: {self.num_classes} and ' \ + f'cls_score.shape[1]: {cls_score.shape[1]}' + else: + self.num_classes = cls_score.shape[1] + + one_hot_like_label = self.generate_one_hot_like_label(label=label) + assert one_hot_like_label.shape == cls_score.shape, \ + f'LabelSmoothLoss requires output and target ' \ + f'to be same shape, but got output.shape: {cls_score.shape} ' \ + f'and target.shape: {one_hot_like_label.shape}' + + smoothed_label = self.smooth_label(one_hot_like_label) + return self.loss_weight * self.ce.forward( + cls_score, + smoothed_label, + weight=weight, + avg_factor=avg_factor, + reduction_override=reduction_override, + **kwargs) diff --git a/mmpretrain/models/losses/reconstruction_loss.py b/mmpretrain/models/losses/reconstruction_loss.py new file mode 100644 index 0000000..40e6bfd --- /dev/null +++ b/mmpretrain/models/losses/reconstruction_loss.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class PixelReconstructionLoss(BaseModule): + """Loss for the reconstruction of pixel in Masked Image Modeling. + + This module measures the distance between the target image and the + reconstructed image and compute the loss to optimize the model. Currently, + This module only provides L1 and L2 loss to penalize the reconstructed + error. In addition, a mask can be passed in the ``forward`` function to + only apply loss on visible region, like that in MAE. + + Args: + criterion (str): The loss the penalize the reconstructed error. + Currently, only supports L1 and L2 loss + channel (int, optional): The number of channels to average the + reconstruction loss. If not None, the reconstruction loss + will be divided by the channel. Defaults to None. + """ + + def __init__(self, criterion: str, channel: Optional[int] = None) -> None: + super().__init__() + + if criterion == 'L1': + self.penalty = torch.nn.L1Loss(reduction='none') + elif criterion == 'L2': + self.penalty = torch.nn.MSELoss(reduction='none') + else: + raise NotImplementedError(f'Currently, PixelReconstructionLoss \ + only supports L1 and L2 loss, but get {criterion}') + + self.channel = channel if channel is not None else 1 + + def forward(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward function to compute the reconstrction loss. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + loss = self.penalty(pred, target) + + # if the dim of the loss is 3, take the average of the loss + # along the last dim + if len(loss.shape) == 3: + loss = loss.mean(dim=-1) + + if mask is None: + loss = loss.mean() + else: + loss = (loss * mask).sum() / mask.sum() / self.channel + + return loss diff --git a/mmpretrain/models/losses/seesaw_loss.py b/mmpretrain/models/losses/seesaw_loss.py new file mode 100644 index 0000000..4aaaa45 --- /dev/null +++ b/mmpretrain/models/losses/seesaw_loss.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# migrate from mmdetection with modifications +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS +from .utils import weight_reduce_loss + + +def seesaw_ce_loss(cls_score, + labels, + weight, + cum_samples, + num_classes, + p, + q, + eps, + reduction='mean', + avg_factor=None): + """Calculate the Seesaw CrossEntropy loss. + + Args: + cls_score (torch.Tensor): The prediction with shape (N, C), + C is the number of classes. + labels (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor): Sample-wise loss weight. + cum_samples (torch.Tensor): Cumulative samples for each category. + num_classes (int): The number of classes. + p (float): The ``p`` in the mitigation factor. + q (float): The ``q`` in the compenstation factor. + eps (float): The minimal value of divisor to smooth + the computation of compensation factor + reduction (str, optional): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + + Returns: + torch.Tensor: The calculated loss + """ + assert cls_score.size(-1) == num_classes + assert len(cum_samples) == num_classes + + onehot_labels = F.one_hot(labels, num_classes) + seesaw_weights = cls_score.new_ones(onehot_labels.size()) + + # mitigation factor + if p > 0: + sample_ratio_matrix = cum_samples[None, :].clamp( + min=1) / cum_samples[:, None].clamp(min=1) + index = (sample_ratio_matrix < 1.0).float() + sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index + ) # M_{ij} + mitigation_factor = sample_weights[labels.long(), :] + seesaw_weights = seesaw_weights * mitigation_factor + + # compensation factor + if q > 0: + scores = F.softmax(cls_score.detach(), dim=1) + self_scores = scores[ + torch.arange(0, len(scores)).to(scores.device).long(), + labels.long()] + score_matrix = scores / self_scores[:, None].clamp(min=eps) + index = (score_matrix > 1.0).float() + compensation_factor = score_matrix.pow(q) * index + (1 - index) + seesaw_weights = seesaw_weights * compensation_factor + + cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels)) + + loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none') + + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + return loss + + +@MODELS.register_module() +class SeesawLoss(nn.Module): + """Implementation of seesaw loss. + + Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021) + `_ + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid of softmax. + Only False is supported. Defaults to False. + p (float): The ``p`` in the mitigation factor. + Defaults to 0.8. + q (float): The ``q`` in the compenstation factor. + Defaults to 2.0. + num_classes (int): The number of classes. + Defaults to 1000 for the ImageNet dataset. + eps (float): The minimal value of divisor to smooth + the computation of compensation factor, default to 1e-2. + reduction (str): The method that reduces the loss to a scalar. + Options are "none", "mean" and "sum". Defaults to "mean". + loss_weight (float): The weight of the loss. 
Defaults to 1.0 + """ + + def __init__(self, + use_sigmoid=False, + p=0.8, + q=2.0, + num_classes=1000, + eps=1e-2, + reduction='mean', + loss_weight=1.0): + super(SeesawLoss, self).__init__() + assert not use_sigmoid, '`use_sigmoid` is not supported' + self.use_sigmoid = False + self.p = p + self.q = q + self.num_classes = num_classes + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + self.cls_criterion = seesaw_ce_loss + + # cumulative samples for each category + self.register_buffer('cum_samples', + torch.zeros(self.num_classes, dtype=torch.float)) + + def forward(self, + cls_score, + labels, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + cls_score (torch.Tensor): The prediction with shape (N, C). + labels (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + Returns: + torch.Tensor: The calculated loss + """ + assert reduction_override in (None, 'none', 'mean', 'sum'), \ + f'The `reduction_override` should be one of (None, "none", ' \ + f'"mean", "sum"), but get "{reduction_override}".' + assert cls_score.size(0) == labels.view(-1).size(0), \ + f'Expected `labels` shape [{cls_score.size(0)}], ' \ + f'but got {list(labels.size())}' + reduction = ( + reduction_override if reduction_override else self.reduction) + assert cls_score.size(-1) == self.num_classes, \ + f'The channel number of output ({cls_score.size(-1)}) does ' \ + f'not match the `num_classes` of seesaw loss ({self.num_classes}).' + + # accumulate the samples for each category + unique_labels = labels.unique() + for u_l in unique_labels: + inds_ = labels == u_l.item() + self.cum_samples[u_l] += inds_.sum() + + if weight is not None: + weight = weight.float() + else: + weight = labels.new_ones(labels.size(), dtype=torch.float) + + # calculate loss_cls_classes + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, labels, weight, self.cum_samples, self.num_classes, + self.p, self.q, self.eps, reduction, avg_factor) + + return loss_cls diff --git a/mmpretrain/models/losses/swav_loss.py b/mmpretrain/models/losses/swav_loss.py new file mode 100644 index 0000000..c7dbb78 --- /dev/null +++ b/mmpretrain/models/losses/swav_loss.py @@ -0,0 +1,190 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from mmengine.dist import all_reduce +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@torch.no_grad() +def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int, + world_size: int, epsilon: float) -> torch.Tensor: + """Apply the distributed sinknorn optimization on the scores matrix to find + the assignments. + + This function is modified from + https://github.com/facebookresearch/swav/blob/main/main_swav.py + + Args: + out (torch.Tensor): The scores matrix + sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp + algorithm. + world_size (int): The world size of the process group. + epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm. + + Returns: + torch.Tensor: Output of sinkhorn algorithm. 
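+
+    Example:
+        A single-process sketch (``world_size=1``); the matrix size and the
+        ``epsilon`` value below are illustrative only::
+
+            >>> import torch
+            >>> scores = torch.rand(8, 4)   # (batch_size, num_prototypes)
+            >>> q = distributed_sinkhorn(scores, sinkhorn_iterations=3,
+            ...                          world_size=1, epsilon=0.05)
+            >>> q.shape                     # soft assignments per sample
+            torch.Size([8, 4])
+            >>> torch.allclose(q.sum(dim=1), torch.ones(8))
+            True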
+ """ + eps_num_stab = 1e-12 + Q = torch.exp(out / epsilon).t( + ) # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + all_reduce(sum_Q) + Q /= sum_Q + + for it in range(sinkhorn_iterations): + # normalize each row: total weight per prototype must be 1/K + u = torch.sum(Q, dim=1, keepdim=True) + if len(torch.nonzero(u == 0)) > 0: + Q += eps_num_stab + u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype) + all_reduce(u) + Q /= u + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + +class MultiPrototypes(BaseModule): + """Multi-prototypes for SwAV head. + + Args: + output_dim (int): The output dim from SwAV neck. + num_prototypes (List[int]): The number of prototypes needed. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + output_dim: int, + num_prototypes: List[int], + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(num_prototypes, list) + self.num_heads = len(num_prototypes) + for i, k in enumerate(num_prototypes): + self.add_module('prototypes' + str(i), + nn.Linear(output_dim, k, bias=False)) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Run forward for every prototype.""" + out = [] + for i in range(self.num_heads): + out.append(getattr(self, 'prototypes' + str(i))(x)) + return out + + +@MODELS.register_module() +class SwAVLoss(BaseModule): + """The Loss for SwAV. + + This Loss contains clustering and sinkhorn algorithms to compute Q codes. + Part of the code is borrowed from `script + `_. + The queue is built in `engine/hooks/swav_hook.py`. + + Args: + feat_dim (int): feature dimension of the prototypes. + sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp + algorithm. Defaults to 3. + epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm. + Defaults to 0.05. + temperature (float): temperature parameter in training loss. + Defaults to 0.1. + crops_for_assign (List[int]): list of crops id used for computing + assignments. Defaults to [0, 1]. + num_crops (List[int]): list of number of crops. Defaults to [2]. + num_prototypes (int): number of prototypes. Defaults to 3000. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + feat_dim: int, + sinkhorn_iterations: int = 3, + epsilon: float = 0.05, + temperature: float = 0.1, + crops_for_assign: List[int] = [0, 1], + num_crops: List[int] = [2], + num_prototypes: int = 3000, + init_cfg: Optional[Union[List[dict], dict]] = None): + super().__init__(init_cfg=init_cfg) + self.sinkhorn_iterations = sinkhorn_iterations + self.epsilon = epsilon + self.temperature = temperature + self.crops_for_assign = crops_for_assign + self.num_crops = num_crops + self.use_queue = False + self.queue = None + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + # prototype layer + self.prototypes = None + if isinstance(num_prototypes, list): + self.prototypes = MultiPrototypes(feat_dim, num_prototypes) + elif num_prototypes > 0: + self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False) + assert self.prototypes is not None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function of SwAV loss. + + Args: + x (torch.Tensor): NxC input features. + Returns: + torch.Tensor: The returned loss. + """ + # normalize the prototypes + with torch.no_grad(): + w = self.prototypes.weight.data.clone() + w = nn.functional.normalize(w, dim=1, p=2) + self.prototypes.weight.copy_(w) + + embedding, output = x, self.prototypes(x) + embedding = embedding.detach() + + bs = int(embedding.size(0) / sum(self.num_crops)) + loss = 0 + for i, crop_id in enumerate(self.crops_for_assign): + with torch.no_grad(): + out = output[bs * crop_id:bs * (crop_id + 1)].detach() + # time to use the queue + if self.queue is not None: + if self.use_queue or not torch.all(self.queue[i, + -1, :] == 0): + self.use_queue = True + out = torch.cat( + (torch.mm(self.queue[i], + self.prototypes.weight.t()), out)) + # fill the queue + self.queue[i, bs:] = self.queue[i, :-bs].clone() + self.queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) * + bs] + + # get assignments (batch_size * num_prototypes) + q = distributed_sinkhorn(out, self.sinkhorn_iterations, + self.world_size, self.epsilon)[-bs:] + + # cluster assignment prediction + subloss = 0 + for v in np.delete(np.arange(np.sum(self.num_crops)), crop_id): + x = output[bs * v:bs * (v + 1)] / self.temperature + subloss -= torch.mean( + torch.sum(q * nn.functional.log_softmax(x, dim=1), dim=1)) + loss += subloss / (np.sum(self.num_crops) - 1) + loss /= len(self.crops_for_assign) + return loss diff --git a/mmpretrain/models/losses/utils.py b/mmpretrain/models/losses/utils.py new file mode 100644 index 0000000..a65b68a --- /dev/null +++ b/mmpretrain/models/losses/utils.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch +import torch.nn.functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. 
+ + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + ``loss_func(pred, target, **kwargs)``. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like ``loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)``. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper + + +def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: + """This function converts target class indices to one-hot vectors, given + the number of classes. + + Args: + targets (Tensor): The ground truth label of the prediction + with shape (N, 1) + classes (int): the number of classes. + + Returns: + Tensor: Processed loss values. + """ + assert (torch.max(targets).item() < + classes), 'Class Index must be less than number of classes' + one_hot_targets = F.one_hot( + targets.long().squeeze(-1), num_classes=classes) + return one_hot_targets diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py new file mode 100644 index 0000000..e68504c --- /dev/null +++ b/mmpretrain/models/multimodal/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmpretrain.utils.dependency import WITH_MULTIMODAL + +if WITH_MULTIMODAL: + from .blip import * # noqa: F401,F403 + from .blip2 import * # noqa: F401,F403 + from .chinese_clip import * # noqa: F401, F403 + from .clip import * # noqa: F401, F403 + from .flamingo import * # noqa: F401, F403 + from .llava import * # noqa: F401, F403 + from .minigpt4 import * # noqa: F401, F403 + from .ofa import * # noqa: F401, F403 + from .otter import * # noqa: F401, F403 + from .ram import * # noqa: F401, F403 +else: + from mmpretrain.registry import MODELS + from mmpretrain.utils.dependency import register_multimodal_placeholder + + register_multimodal_placeholder([ + 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', + 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', + 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', + 'CLIPZeroShot', 'RAM', 'RAMNormal', 'RAMOpenset' + ], MODELS) diff --git a/mmpretrain/models/multimodal/blip/__init__.py b/mmpretrain/models/multimodal/blip/__init__.py new file mode 100644 index 0000000..ebbc0da --- /dev/null +++ b/mmpretrain/models/multimodal/blip/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .blip_caption import BlipCaption +from .blip_grounding import BlipGrounding +from .blip_nlvr import BlipNLVR +from .blip_retrieval import BlipRetrieval +from .blip_vqa import BlipVQA +from .language_model import BertLMHeadModel, XBertEncoder, XBertLMHeadDecoder + +__all__ = [ + 'BertLMHeadModel', 'BlipCaption', 'BlipGrounding', 'BlipNLVR', + 'BlipRetrieval', 'BlipVQA', 'XBertEncoder', 'XBertLMHeadDecoder' +] diff --git a/mmpretrain/models/multimodal/blip/blip_caption.py b/mmpretrain/models/multimodal/blip/blip_caption.py new file mode 100644 index 0000000..9af3e24 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_caption.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class BlipCaption(BaseModel): + """BLIP Caption. + + Args: + vision_encoder (dict): Encoder for extracting image features. + decoder_head (dict): The decoder head module to forward and + calculate loss from processed features. + tokenizer: (Optional[dict]): The config for tokenizer. + Defaults to None. + prompt (str): Prompt used for training and eval. + Defaults to ''. + max_txt_len (int): Max text length of input text. + num_captions (int): Number of captions to be generated for each image. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. 
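+
+    Example:
+        A hedged usage sketch through the high-level inference API added in
+        ``mmpretrain/apis``; the model alias and the image path below are
+        placeholders, not values defined by this patch, and the exact return
+        format may differ::
+
+            >>> from mmpretrain.apis import ImageCaptionInferencer
+            >>> inferencer = ImageCaptionInferencer(
+            ...     '<blip-caption-config-or-checkpoint>')  # placeholder
+            >>> inferencer('<path/to/image.jpg>')           # placeholder path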
+ """ + + def __init__(self, + vision_encoder: dict, + decoder_head: dict, + tokenizer: Optional[dict] = None, + prompt: str = '', + max_txt_len: int = 20, + num_captions: int = 1, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(BlipCaption, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.visual_encoder = MODELS.build(vision_encoder) + self.seq_gen_head = MODELS.build(decoder_head) + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + self.max_txt_len = max_txt_len + self.num_captions = num_captions + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + The method should accept two modes: "predict" and "loss": + + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (torch.Tensor): pre_processed img tensor (N, C, ...). + data_samples (List[DataSample], optional): Data samples with + additional infos. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, images, data_samples=None, **kwargs): + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input images tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # prepare inputs for decoder generation. 
+ image_embeds = self.visual_encoder(images)[0] + image_embeds = torch.repeat_interleave(image_embeds, self.num_captions, + 0) + + prompt = [self.prompt] * image_embeds.size(0) + prompt = self.tokenizer( + prompt, padding='longest', + return_tensors='pt').to(image_embeds.device) + + prompt.input_ids[:, 0] = self.tokenizer.bos_token_id + prompt.input_ids = prompt.input_ids[:, :-1] + + decoder_out = self.seq_gen_head.predict( + input_ids=prompt.input_ids, + encoder_hidden_states=image_embeds, + sep_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + output_attentions=True, + return_dict_in_generate=True, + ) + + decode_tokens = self.tokenizer.batch_decode( + decoder_out.sequences, skip_special_tokens=True) + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(len(decode_tokens))] + + for data_sample, decode_token in zip(data_samples, decode_tokens): + if data_sample is None: + data_sample = DataSample() + data_sample.pred_caption = decode_token[len(self.prompt):] + out_data_samples.append(data_sample) + + return out_data_samples + + def loss(self, images, data_samples): + """Calculate losses from a batch of images and data samples. + + Args: + images (torch.Tensor): The input images tensor with shape + (N, C, ...) in general. + data_samples (List[ImageTextDataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + image_embeds = self.visual_encoder(images)[0] + raw_text = [self.prompt + ds.gt_caption for ds in data_samples] + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(image_embeds.device) + text.input_ids[:, 0] = self.tokenizer.bos_token_id + + # prepare targets for forwarding decoder + labels = text.input_ids.masked_fill( + text.input_ids == self.tokenizer.pad_token_id, -100) + labels[:, :self.prompt_length] = -100 + # forward decoder + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) + + losses = self.seq_gen_head.loss( + input_ids=text.input_ids, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + labels=labels, + ) + return losses diff --git a/mmpretrain/models/multimodal/blip/blip_grounding.py b/mmpretrain/models/multimodal/blip/blip_grounding.py new file mode 100644 index 0000000..cb08728 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_grounding.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.model import BaseModel + +from mmpretrain.models.utils.box_utils import box_xyxy_to_cxcywh +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures.data_sample import DataSample + + +@MODELS.register_module() +class BlipGrounding(BaseModel): + """BLIP Grounding. + + Args: + visual_encoder (dict): Backbone for extracting image features. + text_encoder (dict): Backbone for extracting text features. + but we integrate the vqa text extractor + into the tokenizer part in datasets/transform/ + so we don't need text_backbone + multimodal_encoder (Optional[dict]): Backbone for extracting + multi-modal features. We apply this part as VQA fusion module. + neck (Optional[dict]): The neck module to process features from + backbone. Defaults to None. 
+ head (Optional[Union[List[dict], dict]]): The head module to calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + tokenizer: Optional[dict] = None, + visual_encoder: Optional[dict] = None, + text_encoder: Optional[dict] = None, + multimodal_encoder: Optional[dict] = None, + head: Optional[Union[List[dict], dict]] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(BlipGrounding, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.prompt = 'localize instance: ' + self.visual_encoder = MODELS.build(visual_encoder) + self.text_encoder = MODELS.build(text_encoder) + self.multimodal_encoder = MODELS.build(multimodal_encoder) + head.setdefault('tokenizer', self.tokenizer) + self.grounding_head = MODELS.build(head) + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + The method should accept only one mode "loss": + + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor, tuple): The input tensor with shape + (N, C, ...) in general. + data_samples (List[VQADataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, images: torch.Tensor) -> torch.Tensor: + """Extract features from the input tensor with shape (N, C, ...). + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + Returns: + image_embeds (Tensor): The output features. + """ + image_embeds = self.visual_encoder(images)[0] + return image_embeds + + def loss( + self, + images: torch.Tensor, + data_samples=None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """generate train_loss from the input tensor and data_samples. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + data_samples (List[VQADataSample], optional): The annotation + data of every samples.. + + Returns: + Dict[torch.Tensor]: The losses features. 
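+
+        Example:
+            A sketch of how a ground-truth ``xyxy`` box is normalised by the
+            image shape and converted to the ``cxcywh`` regression target;
+            the box coordinates and image size below are illustrative only::
+
+                >>> import torch
+                >>> from mmpretrain.models.utils.box_utils import (
+                ...     box_xyxy_to_cxcywh)
+                >>> box = torch.tensor([[50., 40., 150., 120.]])  # xyxy pixels
+                >>> h, w = 200, 300                               # img_shape
+                >>> box = box / torch.tensor([w, h, w, h])        # 0-1 range
+                >>> box_xyxy_to_cxcywh(box)
+                tensor([[0.3333, 0.4000, 0.3333, 0.4000]])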
+ """ + + # extract image feature + image_embeds = self.extract_feat(images) + image_atts = image_embeds.new_ones( + image_embeds.size()[:-1], dtype=torch.long) + + raw_text = [] + box_targets = [] + for ds in data_samples: + + raw_text.append(ds.text) + box_t = copy.deepcopy(ds.box) * 1.0 + box_t[1] /= ds.img_shape[0] + box_t[3] /= ds.img_shape[0] + box_t[0] /= ds.img_shape[1] + box_t[2] /= ds.img_shape[1] + + box_targets.append(box_t) + + box_targets = image_embeds.new_tensor(np.stack(box_targets)) + box_targets = box_xyxy_to_cxcywh(box_targets) # xywh 0-1 + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=128, + return_tensors='pt', + ).to(image_embeds.device) + + text_embeds = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + mode='text', + return_dict=True) # bz, seq_len, hid + + # multimodal fusion + multimodal_embeds = self.multimodal_encoder( + encoder_embeds=text_embeds.last_hidden_state, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + losses = self.grounding_head.loss( + text_embedding=multimodal_embeds.last_hidden_state, + text_embedding_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + decoder_targets=box_targets, + ) + + return losses + + def predict(self, images, data_samples=None): + """""" + + # extract image feature + image_embeds = self.extract_feat(images) + image_atts = image_embeds.new_ones( + image_embeds.size()[:-1], dtype=torch.long) + + raw_text = [] + for ds in data_samples: + raw_text.append(ds.text) + + text = self.tokenizer( + raw_text, + padding='longest', + truncation=True, + max_length=128, + return_tensors='pt', + ).to(image_embeds.device) + + text_embeds = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + mode='text', + return_dict=True) # bz, seq_len, hid + + # multimodal fusion + multimodal_embeds = self.multimodal_encoder( + encoder_embeds=text_embeds.last_hidden_state, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + output_boxes = self.grounding_head.predict( + text_embedding=multimodal_embeds.last_hidden_state, + text_embedding_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + ) # xyxy 0-1 + + out_data_samples = [] + for bbox, data_sample, img in zip(output_boxes, data_samples, images): + if data_sample is None: + data_sample = DataSample() + + img_size = img.shape[-2:] + scale_factor = data_sample.get('scale_factor', (1, 1)) + bbox[0::2] = bbox[0::2] * img_size[1] / scale_factor[0] + bbox[1::2] = bbox[1::2] * img_size[0] / scale_factor[1] + bbox = bbox[None, :] + data_sample.pred_bboxes = bbox + + if 'gt_bboxes' in data_sample: + gt_bboxes = torch.Tensor(data_sample.get('gt_bboxes')) + gt_bboxes[:, 0::2] /= scale_factor[0] + gt_bboxes[:, 1::2] /= scale_factor[1] + data_sample.gt_bboxes = gt_bboxes + + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/blip/blip_nlvr.py b/mmpretrain/models/multimodal/blip/blip_nlvr.py new file mode 100644 index 0000000..f96e3cc --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_nlvr.py @@ -0,0 +1,205 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER + + +@MODELS.register_module() +class BlipNLVR(BaseModel): + """BLIP NLVR. + + Args: + vision_backbone (dict): Backbone for extracting image features. + text_backbone (dict): Backbone for extracting text features. + but we integrate the vqa text extractor into the tokenizer part in + datasets/transform/ so we don't need text_backbone + multimodal_backbone (Optional[dict]): Backbone for extracting + multi-modal features. We apply this part as VQA fusion module. + neck (Optional[dict]): The neck module to process features from + backbone. Defaults to None. + head (Optional[dict]): The head module to calculate + loss from processed features. See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + tokenizer: (Optional[dict]): The config for tokenizer + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + multimodal_backbone: dict, + tokenizer: Optional[dict] = None, + max_txt_len: int = 35, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + if tokenizer is not None: + self.tokenizer = TOKENIZER.build(tokenizer) + self.vision_backbone = MODELS.build(vision_backbone) + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.max_txt_len = max_txt_len + + # For simplity, directly use head definition here. + # If more complex head is designed, move this and loss to a new + # head module. + hidden_size = self.multimodal_backbone.config.hidden_size + self.head = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, 2), + ) + + @property + def device(self): + return next(self.parameters()).device + + def preprocess_text(self, data_samples): + + sample_item = data_samples[0] + + if sample_item is not None and 'text' in sample_item: + texts = [sample.get('text') for sample in data_samples] + else: + return None + + # perform tokenize first if satisfied conditions + texts = self.tokenizer( + texts, + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(self.device) + + return texts + + def forward( + self, + images: dict, + data_samples: Optional[List] = None, + mode: str = 'tensor', + ): + """The unified entry for a forward process in both training and test. + The method should accept only one mode "loss": + + - "loss": Forward and return a dict of losses according to the given + images and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (dict of torch.Tensor): + img: pre_processed img tensor (N, C, ...). 
+ text: tokenized text (N, L) + data_samples (List[CaptionDataSample], optional): + The annotation data of every samples. + 'image': raw image data + 'text' tokenized text + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + """ + # B, T, C, H, W to T*B, C, H, W + images = images.permute(1, 0, 2, 3, 4).flatten(0, 1) + + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, images, data_samples=None): + """Predict caption.""" + # prepare inputs for decoder generation. + image_embeds = self.vision_backbone(images)[0] + texts = self.preprocess_text(data_samples) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + image0_embeds, image1_embeds = torch.split(image_embeds, + texts.input_ids.size(0)) + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + texts.input_ids, + attention_mask=texts.attention_mask, + encoder_hidden_states=[image0_embeds, image1_embeds], + encoder_attention_mask=[ + image_atts[:image0_embeds.size(0)], + image_atts[image0_embeds.size(0):], + ], + return_dict=True, + ) + + # get prediction + outputs = self.head(multimodal_embeds.last_hidden_state[:, 0, :]) + + pred_scores = F.softmax(outputs, dim=1) + + for pred_score, data_sample in zip(pred_scores, data_samples): + data_sample.set_pred_score(pred_score) + data_sample.set_pred_label(pred_score.argmax(dim=0)) + + return data_samples + + def loss(self, images, data_samples): + """Calculate losses from a batch of inputs and data samples. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[ImageTextDataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # prepare inputs for decoder generation. + image_embeds = self.vision_backbone(images)[0] + texts = self.preprocess_text(data_samples) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + image0_embeds, image1_embeds = torch.split(image_embeds, + texts.input_ids.size(0)) + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + texts.input_ids, + attention_mask=texts.attention_mask, + encoder_hidden_states=[image0_embeds, image1_embeds], + encoder_attention_mask=[ + image_atts[:image0_embeds.size(0)], + image_atts[image0_embeds.size(0):], + ], + return_dict=True, + ) + + # get prediction + outputs = self.head(multimodal_embeds.last_hidden_state[:, 0, :]) + + targets = torch.tensor([i.gt_label + for i in data_samples]).to(outputs.device) + loss = F.cross_entropy(outputs, targets) + return {'loss': loss} diff --git a/mmpretrain/models/multimodal/blip/blip_retrieval.py b/mmpretrain/models/multimodal/blip/blip_retrieval.py new file mode 100644 index 0000000..3ebc251 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_retrieval.py @@ -0,0 +1,716 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from collections import ChainMap +from copy import deepcopy +from typing import Dict, List, Optional, Tuple, Union + +import mmengine.dist as dist +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModel +from torch import distributed as torch_dist + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process + + +def all_gather_concat(data: torch.Tensor) -> torch.Tensor: + """Gather tensors with different first-dimension size and concat to one + tenosr. + + Note: + Only the first dimension should be different. + + Args: + data (Tensor): Tensor to be gathered. + + Returns: + torch.Tensor: The concatenated tenosr. + """ + if dist.get_world_size() == 1: + return data + + data_size = torch.tensor(data.size(0), device=data.device) + sizes_list = dist.all_gather(data_size) + + max_length = max(sizes_list) + size_diff = max_length.item() - data_size.item() + if size_diff: + padding = torch.zeros( + size_diff, *data.size()[1:], device=data.device, dtype=data.dtype) + data = torch.cat((data, padding)) + + gather_list = dist.all_gather(data) + + all_data = [] + for tensor, size in zip(gather_list, sizes_list): + + all_data.append(tensor[:size]) + + return torch.concat(all_data) + + +@MODELS.register_module() +class BlipRetrieval(BaseModel): + """BLIP Retriever. + + Args: + vision_backbone (dict): Backbone for extracting image features. + text_backbone (dict): Backbone for extracting text features. + multimodal_backbone (Optional[dict]): Backbone for extracting + multi-modal features. + vision_neck (Optional[dict]): The neck module to process image features + from vision backbone. Defaults to None. + text_neck (Optional[dict]): The neck module to process text features + from text backbone. Defaults to None. + head (Optional[Union[List[dict], dict]]): The head module to calculate + loss from processed single modality features. + See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + multimodal_head (Optional[Union[List[dict], dict]]): The multi-modal + head module to calculate loss from processed multimodal features. + See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + momentum (float): Momentum used for momentum contrast. + Defaults to .995. + negative_all_rank (bool): Whether to sample negative data from all + ranks for image text matching in training. Defaults to True. + temperature (float): Temperature parameter that controls the + concentration level of the distribution. Defaults to 0.07. + fast_match (bool): If False, select topk similarity as candidates and + compute the matching score. If True, return the similarity as the + matching score directly. Defaults to False. + topk (int): Select topk similarity as candidates for compute matching + scores. Notice that this is not the topk in evaluation. + Defaults to 256. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. 
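+
+    Example:
+        A minimal sketch of the temperature-scaled similarity matrix that the
+        contrastive head and the matching scores are built on; the feature
+        size, batch size and temperature below are illustrative only::
+
+            >>> import torch
+            >>> import torch.nn.functional as F
+            >>> image_feat = F.normalize(torch.randn(4, 256), dim=-1)
+            >>> text_feat = F.normalize(torch.randn(4, 256), dim=-1)
+            >>> sim_i2t = image_feat @ text_feat.t() / 0.07
+            >>> sim_i2t.shape   # one row of logits per image
+            torch.Size([4, 4])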
+ """ + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + multimodal_backbone: Optional[dict] = None, + vision_neck: Optional[dict] = None, + text_neck: Optional[dict] = None, + head: Optional[Union[List[dict], dict]] = None, + multimodal_head: Optional[Union[List[dict], dict]] = None, + tokenizer: Optional[dict] = None, + momentum: float = .995, + negative_all_rank: bool = True, + temperature: float = 0.07, + fast_match: bool = False, + topk: int = 256, + max_txt_len: int = 20, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.vision_backbone = MODELS.build(vision_backbone) + self.text_backbone = MODELS.build(text_backbone) + + if multimodal_backbone is not None: + self.multimodal_backbone = MODELS.build(multimodal_backbone) + + if vision_neck is not None: + self.vision_neck = MODELS.build(vision_neck) + + if text_neck is not None: + self.text_neck = MODELS.build(text_neck) + + if head is not None: + self.head = MODELS.build(head) + + if multimodal_head is not None: + self.multimodal_head = MODELS.build(multimodal_head) + + if tokenizer is not None: + self.tokenizer = TOKENIZER.build(tokenizer) + + self.momentum = momentum + self.negative_all_rank = negative_all_rank + self.temp = nn.Parameter(temperature * torch.ones([])) + # Shares the same para + self.head.temp = self.temp + + # create the momentum encoder + self.vision_backbone_m = deepcopy(self.vision_backbone) + self.text_backbone_m = deepcopy(self.text_backbone) + + self.vision_neck_m = deepcopy(self.vision_neck) + self.text_neck_m = deepcopy(self.text_neck) + + self.model_pairs = [ + [self.vision_backbone, self.vision_backbone_m], + [self.text_backbone, self.text_backbone_m], + [self.vision_neck, self.vision_neck_m], + [self.text_neck, self.text_neck_m], + ] + self.copy_params() + + # multimodal backbone shares weights with text backbone in BLIP + # No need to set up + + # Notice that this topk is used for select k candidate to compute + # image-text score, but not the final metric topk in evaluation. + self.fast_match = fast_match + self.topk = topk + + self.max_txt_len = max_txt_len + + @property + def device(self): + return next(self.parameters()).device + + def preprocess_text(self, data_samples): + sample_item = data_samples[0] + + if sample_item is not None and 'text' in sample_item: + if isinstance(sample_item.get('text'), (list, tuple)): + texts = [] + for sample in data_samples: + texts.extend(sample.get('text')) + elif isinstance(sample_item.get('text'), str): + texts = [sample.get('text') for sample in data_samples] + else: + raise TypeError('text must be a string or a list of strings') + else: + return None + + # perform tokenize first if satisfied conditions + texts = self.tokenizer( + texts, + padding='max_length', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(self.device) + + return texts + + def forward(self, + images: torch.tensor = None, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor') -> Union[Tuple, dict]: + """The unified entry for a forward process in both training and test. 
+ The method should accept two modes: "tensor", and "loss": + + - "tensor": Forward the whole network and return tensor without any + post-processing, same as a common nn.Module. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + For unified "predict" mode in other mm repos. It is noticed that + image-text retrieval cannot perform batch prediction since it will go + through all the samples. A standard process of retrieval evaluation is + to extract and collect all feats, and then predict all samples. + Therefore the `predict` mode here is remained as a trigger + to inform use to choose the right configurations. + + Args: + images (torch.Tensor): The input inputs tensor of shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + - If ``mode="tensor"``, return a tuple. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + return self.extract_feat(images, data_samples) + elif mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat( + self, + images: torch.Tensor = None, + data_samples: List[DataSample] = None, + return_texts=True, + return_embeds=None, + ) -> Dict[str, torch.Tensor]: + """Extract features from the input dict. + + Args: + images (tensor, optional): The images to extract features. + Defaults to None. + data_samples (list, optional): The data samples containing texts + to extract features. Defaults to None. + return_texts (bool): Whether to return the tokenized text and the + corresponding attention masks. Defaults to True. + return_embeds (bool): Whether to return the text embedding and + image embedding. Defaults to None, which means to use + ``self.fast_match``. + + Returns: + Tuple[torch.Tensor]: The output features. + If multimodal_backbone is not exist, tuple of torch.Tensor + will be returned. + """ + if data_samples is not None: + texts = self.preprocess_text(data_samples) + else: + texts = None + + assert images is not None or texts is not None, \ + 'At least single modality should be passed as inputs.' + + results = {} + if texts is not None and return_texts: + results.update({ + 'text_ids': texts.input_ids, + 'text_attn_mask': texts.attention_mask, + }) + + if return_embeds is None: + return_embeds = not self.fast_match + + # extract image features + if images is not None: + output = self._extract_feat(images, modality='images') + results['image_feat'] = output['image_feat'] + if return_embeds: + results['image_embeds'] = output['image_embeds'] + + # extract text features + if texts is not None: + output = self._extract_feat(texts, modality='texts') + results['text_feat'] = output['text_feat'] + if return_embeds: + results['text_embeds'] = output['text_embeds'] + + return results + + def _extract_feat(self, inputs: Union[torch.Tensor, dict], + modality: str) -> Tuple[torch.Tensor]: + """Extract features from the single modality. + + Args: + inputs (Union[torch.Tensor, dict]): A batch of inputs. + For image, a tensor of shape (N, C, ...) in general. 
+ For text, a dict of tokenized text inputs. + modality (str): Modality feature to be extracted. Only two + options are supported. + + - ``images``: Only extract image features, mostly used for + inference. + - ``texts``: Only extract text features, mostly used for + inference. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + + if modality == 'images': + # extract image features + image_embeds = self.vision_backbone(inputs)[0] + image_feat = F.normalize( + self.vision_neck(image_embeds[:, 0, :]), dim=-1) + return {'image_embeds': image_embeds, 'image_feat': image_feat} + elif modality == 'texts': + # extract text features + text_output = self.text_backbone( + inputs.input_ids, + attention_mask=inputs.attention_mask, + token_type_ids=None, + return_dict=True, + mode='text', + ) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize( + self.text_neck(text_embeds[:, 0, :]), dim=-1) + return {'text_embeds': text_embeds, 'text_feat': text_feat} + else: + raise RuntimeError(f'Invalid modality "{modality}".') + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Dict[str, torch.tensor]: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (dict): A batch of inputs. The input tensor with of + at least one modality. For image, the value is a tensor + of shape (N, C, ...) in general. + For text, the value is a dict of tokenized text inputs. + data_samples (Optional[List[DataSample]]): + The annotation data of every samples. Defaults to None. + + Returns: + Dict[str, torch.tensor]: a dictionary of loss components of + both head and multimodal head. + """ + output = self.extract_feat(images, data_samples, return_embeds=True) + + text_ids = output['text_ids'] + text_attn_mask = output['text_attn_mask'] + image_embeds = output['image_embeds'] + image_feat = output['image_feat'] + text_feat = output['text_feat'] + + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.vision_backbone_m(images)[0] + image_feat_m = F.normalize( + self.vision_neck_m(image_embeds_m[:, 0, :]), dim=-1) + + text_output_m = self.text_backbone_m( + text_ids, + attention_mask=text_attn_mask, + token_type_ids=None, + return_dict=True, + mode='text', + ) + text_embeds_m = text_output_m.last_hidden_state + text_feat_m = F.normalize( + self.text_neck_m(text_embeds_m[:, 0, :]), dim=-1) + + loss = self.head.loss( + ([image_feat, text_feat, image_feat_m, text_feat_m], ), + data_samples) + + # prepare for itm + encoder_input_ids = text_ids.clone() + encoder_input_ids[:, + 0] = self.tokenizer.additional_special_tokens_ids[0] + output_pos = self.text_backbone( + encoder_input_ids, + attention_mask=text_attn_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + idx = torch.tensor([i.image_id for i in data_samples]).view(-1, 1) + bs = idx.size(0) + idxs = torch.cat(dist.all_gather(idx)) + if self.negative_all_rank: + # compute sample similarity + with torch.no_grad(): + mask = torch.eq(idx, idxs.t()).to(self.device) + + image_feat_world = torch.cat(dist.all_gather(image_feat)) + text_feat_world = torch.cat(dist.all_gather(text_feat)) + + sim_i2t = image_feat @ text_feat_world.t() / self.temp + sim_t2i = text_feat @ image_feat_world.t() / self.temp + + weights_i2t = F.softmax(sim_i2t, dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = 
F.softmax(sim_t2i, dim=1) + weights_t2i.masked_fill_(mask, 0) + + world_size = dist.get_world_size() + if world_size == 1: + image_embeds_world = image_embeds + else: + image_embeds_world = torch.cat( + torch_dist.nn.all_gather(image_embeds)) + + # select a negative image (from all ranks) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text (from all ranks) for each image + input_ids_world = torch.cat(dist.all_gather(encoder_input_ids)) + att_mask_world = torch.cat(dist.all_gather(text_attn_mask)) + + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(input_ids_world[neg_idx]) + text_atts_neg.append(att_mask_world[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0) + text_atts_all = torch.cat([text_attn_mask, text_atts_neg], dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + output_neg = self.text_backbone( + text_ids_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = torch.cat( + [ + output_pos.last_hidden_state[:, 0, :], + output_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + + # create false data samples + data_samples.extend( + [DataSample(is_matched=False) for _ in range(2 * bs)]) + loss_multimodal = self.multimodal_head.loss((vl_embeddings, ), + data_samples) + + return dict(ChainMap(loss, loss_multimodal)) + + def predict(self, images, data_samples, cal_i2t=True, cal_t2i=True): + feats = self.extract_feat(images, data_samples) + + return self.predict_all( + feats, data_samples, cal_i2t=cal_i2t, cal_t2i=cal_t2i) + + def predict_all(self, + feats, + data_samples, + num_images=None, + num_texts=None, + cal_i2t=True, + cal_t2i=True): + text_ids = feats['text_ids'] + text_ids[:, 0] = self.tokenizer.additional_special_tokens_ids[0] + text_attn_mask = feats['text_attn_mask'] + image_embeds = feats.get('image_embeds', None) + image_feat = feats['image_feat'] + text_feat = feats['text_feat'] + + num_images = num_images or image_feat.size(0) + num_texts = num_texts or text_feat.size(0) + + if not self.fast_match: + image_embeds_all = all_gather_concat(image_embeds)[:num_images] + else: + image_embeds_all = None + image_feat_all = all_gather_concat(image_feat)[:num_images] + text_feat_all = all_gather_concat(text_feat)[:num_texts] + text_ids_all = all_gather_concat(text_ids)[:num_texts] + text_attn_mask_all = all_gather_concat(text_attn_mask)[:num_texts] + + results = [] + if cal_i2t: + result_i2t = self.compute_score_matrix_i2t( + image_feat, + image_embeds, + text_feat_all, + text_ids_all, + text_attn_mask_all, + ) + results.append( + self._get_predictions(result_i2t, data_samples, mode='i2t')) + if cal_t2i: + result_t2i = self.compute_score_matrix_t2i( + image_feat_all, + image_embeds_all, + text_feat, + text_ids, + text_attn_mask, + ) + results.append( + self._get_predictions(result_t2i, data_samples, mode='t2i')) + return tuple(results) + + def compute_score_matrix_i2t(self, img_feats, img_embeds, text_feats, + text_ids, text_atts): + """Compare the score 
matrix for image-to-text retrieval. Every image + should compare to all the text features. + + Args: + img_feats (torch.Tensor): The input img feats tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + img_embeds (torch.Tensor): The input img embeds tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + text_feats (torch.Tensor): The input text feats tensor with shape + (N, C). N stands for numbers of all samples on all GPUs. + text_ids (torch.Tensor): The input tensor with shape (N, C). + text_atts (torch.Tensor): The input tensor with shape (N, C). + + Returns: + torch.Tensor: Score matrix of image-to-text retrieval. + """ + + # compute i2t sim matrix + sim_matrix_i2t = img_feats @ text_feats.t() + if self.fast_match: + return sim_matrix_i2t + + score_matrix_i2t = torch.full((img_feats.size(0), text_feats.size(0)), + -100.0).to(self.device) + for i in track_on_main_process( + range(img_feats.size(0)), 'Compute I2T scores...'): + sims = sim_matrix_i2t[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + + encoder_output = img_embeds[i].repeat(self.topk, 1, 1) + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + output = self.text_backbone( + text_ids[topk_idx], + attention_mask=text_atts[topk_idx], + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, 0, :], ))[:, 1] + score_matrix_i2t[i, topk_idx] = score + topk_sim + + return score_matrix_i2t + + def compute_score_matrix_t2i(self, img_feats, img_embeds, text_feats, + text_ids, text_atts): + """Compare the score matrix for text-to-image retrieval. Every text + should compare to all the image features. + + Args: + img_feats (torch.Tensor): The input img feats tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + img_embeds (torch.Tensor): The input img embeds tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + text_feats (torch.Tensor): The input text feats tensor with shape + (N, C). N stands for numbers of all samples on all GPUs. + text_ids (torch.Tensor): The input tensor with shape (M, C). + text_atts (torch.Tensor): The input tensor with shape (M, C). + + Returns: + torch.Tensor: Score matrix of text-to-image retrieval. + """ + + # compute t2i sim matrix + sim_matrix_t2i = text_feats @ img_feats.t() + if self.fast_match: + return sim_matrix_t2i + + score_matrix_t2i = torch.full((text_feats.size(0), img_feats.size(0)), + -100.0).to(self.device) + for i in track_on_main_process( + range(text_feats.size(0)), 'Compute T2I scores...'): + sims = sim_matrix_t2i[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + + encoder_output = img_embeds[topk_idx] + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + output = self.text_backbone( + text_ids[i].repeat(self.topk, 1), + attention_mask=text_atts[i].repeat(self.topk, 1), + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, 0, :], ))[:, 1] + score_matrix_t2i[i, topk_idx] = score + topk_sim + + return score_matrix_t2i + + def _get_predictions(self, + result: torch.Tensor, + data_samples: List[DataSample], + mode: str = 'i2t'): + """Post-process the output of retriever. + + Args: + result (torch.Tensor): Score matrix of single retrieve, + either from image or text. 
+ data_samples (List[DataSample], optional): The annotation + data of every samples. + mode (str): Retrieve mode, either `i2t` for image to text, or `t2i` + text to image. Defaults to `i2t`. + + Returns: + List[DataSample]: the raw data_samples with + the predicted results. + """ + + # create data sample if not exists + if data_samples is None: + data_samples = [DataSample() for _ in range(result.size(0))] + elif mode == 't2i': + # Process data samples to align with the num of texts. + new_data_samples = [] + for sample in data_samples: + if isinstance(sample.text, (list, tuple)): + texts = sample.text + else: + texts = [sample.text] + for i, text in enumerate(texts): + new_sample = DataSample(text=text) + if 'gt_image_id' in sample: + new_sample.gt_label = sample.gt_image_id[i] + new_data_samples.append(new_sample) + assert len(new_data_samples) == result.size(0) + data_samples = new_data_samples + elif mode == 'i2t': + for sample in data_samples: + if 'gt_text_id' in sample: + sample.gt_label = sample.gt_text_id + else: + raise ValueError(f'Type {mode} is not supported.') + + for data_sample, score in zip(data_samples, result): + idx = score.argmax(keepdim=True).detach() + + data_sample.set_pred_score(score) + data_sample.set_pred_label(idx) + return data_samples + + # TODO: add temperaily + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), + model_pair[1].parameters()): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for (name, + param), (name_m, + param_m) in zip(model_pair[0].named_parameters(), + model_pair[1].named_parameters()): + # hack to behave the same + if any([i in name for i in ['8', '9', '10', '11'] + ]) and 'layers' in name and any( + [i in name for i in ['attn', 'ffn']]): + param_m.data = param.data + else: + param_m.data = param_m.data * self.momentum + \ + param.data * (1.0 - self.momentum) diff --git a/mmpretrain/models/multimodal/blip/blip_vqa.py b/mmpretrain/models/multimodal/blip/blip_vqa.py new file mode 100644 index 0000000..d0f4e58 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/blip_vqa.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class BlipVQA(BaseModel): + """BLIP VQA. + + Args: + tokenizer: (dict): The config for tokenizer. + vision_backbone (dict): Encoder for extracting image features. + multimodal_backbone (dict): Backbone for extracting + multi-modal features. We apply this part as VQA fusion module. + head (dict): The head module to calculate + loss from processed features. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + `MutimodalDataPreprocessor` as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. 
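    Schematically, the components above are assembled from a config dict.
    The sketch below is illustrative only; apart from ``BlipVQA``,
    ``XBertEncoder`` and ``MultiModalDataPreprocessor`` (all defined in this
    patch), the ``type`` strings and values are placeholders rather than a
    verified recipe::

        model = dict(
            type='BlipVQA',
            tokenizer=dict(type='BertTokenizer',
                           name_or_path='bert-base-uncased'),
            vision_backbone=dict(type='VisionTransformer', arch='base'),
            multimodal_backbone=dict(type='XBertEncoder', med_config=dict()),
            head=dict(type='VQAGenerationHead'),
            data_preprocessor=dict(type='MultiModalDataPreprocessor'),
        )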
+ """ + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + multimodal_backbone: dict, + head: dict, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super(BlipVQA, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.vision_backbone = MODELS.build(vision_backbone) + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.vqa_head = MODELS.build(head) + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + + - "loss": For training. Forward and return a dict of losses according + to the given inputs and data samples. Note that this method doesn't + handle neither back propagation nor optimizer updating, which are + done in the :meth:`train_step`. + - "predict": For testing. Forward and return a list of data_sample that + contains pred_answer for each question. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation data of + every samples. Required when ``mode="loss"``. Defaults to None. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + - If ``mode="predict"``, return a list of `DataSample` + """ + + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, images: torch.Tensor) -> torch.Tensor: + """Extract features from the input tensor with shape (N, C, ..). + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + + Returns: + visual_embeds (Tensor): The output features. + """ + # extract visual feature + if images.ndim == 4: + visual_embeds = self.vision_backbone(images)[0] + elif images.ndim == 5: + # [batch, T, C, H, W] -> [batch * T, C, H, W] + bs = images.size(0) + images = images.reshape(-1, *images.shape[2:]) + visual_embeds = self.vision_backbone(images)[0] + # [batch * num_segs, L, dim] -> [batch, num_segs * L, dim] + visual_embeds = visual_embeds.reshape(bs, -1, + *visual_embeds.shape[2:]) + else: + raise ValueError( + f'Images with {images.ndim} dims is not supported.') + return visual_embeds + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """generate train_loss from the input tensor and data_samples. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + Dict[torch.Tensor]: The losses features. 
+ """ + visual_embeds = self.extract_feat(images) + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + questions = [] + for sample in data_samples: + questions.append(sample.get('question')) + questions = self.tokenizer( + questions, padding='longest', return_tensors='pt').to(self.device) + + questions.input_ids[:, 0] = \ + self.tokenizer.additional_special_tokens_ids[0] + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + questions.input_ids, + attention_mask=questions.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + # put answer from data_samples into tensor form + answer_raw_text = [] + for sample in data_samples: + answer_raw_text.extend(sample.gt_answer) + answer = self.tokenizer( + answer_raw_text, padding='longest', + return_tensors='pt').to(self.device) + answer_targets = answer.input_ids.masked_fill( + answer.input_ids == self.tokenizer.pad_token_id, -100) + for sample in data_samples: + # follow BLIP setting, set answer_weight to 0.2 for VG dataset. + if not hasattr(sample, 'gt_answer_weight'): + sample.gt_answer_weight = torch.tensor([0.2]) + else: + sample.gt_answer_weight = torch.tensor(sample.gt_answer_weight) + answer_weight = torch.cat( + [sample.gt_answer_weight for sample in data_samples], + dim=0).to(self.device) + answer_count = torch.tensor( + [len(sample.gt_answer) for sample in data_samples]).to(self.device) + + question_states, question_atts = [], [] + for b, n in enumerate(answer_count): + question_states += [multimodal_embeds.last_hidden_state[b]] * n + question_atts += [questions.attention_mask[b]] * n + + question_states = torch.stack(question_states, dim=0).to(self.device) + question_atts = torch.stack(question_atts, dim=0).to(self.device) + + head_feats = dict( + answer_input_ids=answer.input_ids, + answer_attention_mask=answer.attention_mask, + answer_weight=answer_weight, + answer_targets=answer_targets, + question_states=question_states, + question_atts=question_atts, + batch_size=len(data_samples), + ) + + losses = self.vqa_head.loss(head_feats) + + return losses + + def predict( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ): + """update data_samples that contain pred_answer for each question. + + Args: + images (Tensor): A batch of images. The shape of it should be + (B, C, H, W) for images and (B, T, C, H, W) for videos. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + Dict[torch.Tensor]: The losses features. 
+ """ + visual_embeds = self.extract_feat(images) + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + questions = [] + for sample in data_samples: + questions.append(sample.get('question')) + questions = self.tokenizer( + questions, padding='longest', return_tensors='pt').to(self.device) + + questions.input_ids[:, 0] = \ + self.tokenizer.additional_special_tokens_ids[0] + + # multimodal fusion + multimodal_embeds = self.multimodal_backbone( + questions.input_ids, + attention_mask=questions.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + if self.vqa_head.inference_method == 'rank': + answer_candidates = self.tokenizer( + self.vqa_head.answer_list, + padding='longest', + return_tensors='pt').to(self.device) + answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id + elif self.vqa_head.inference_method == 'generate': + answer_candidates = None + + head_feats = dict( + multimodal_embeds=multimodal_embeds.last_hidden_state, + question_atts=questions.attention_mask, + answer_candidates=answer_candidates, + bos_token_id=self.tokenizer.bos_token_id, + sep_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + if self.vqa_head.inference_method == 'rank': + answers = self.vqa_head.predict(head_feats) + for answer, data_sample in zip(answers, data_samples): + data_sample.pred_answer = answer + + elif self.vqa_head.inference_method == 'generate': + outputs = self.vqa_head.predict(head_feats) + for output, data_sample in zip(outputs, data_samples): + data_sample.pred_answer = self.tokenizer.decode( + output, skip_special_tokens=True) + + return data_samples diff --git a/mmpretrain/models/multimodal/blip/language_model.py b/mmpretrain/models/multimodal/blip/language_model.py new file mode 100644 index 0000000..48605a9 --- /dev/null +++ b/mmpretrain/models/multimodal/blip/language_model.py @@ -0,0 +1,1320 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +# flake8: noqa + +import math +from typing import Tuple + +import torch +import torch.nn as nn +from torch import Tensor, device + +try: + from transformers.activations import ACT2FN + from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) + from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + from transformers.models.bert.configuration_bert import BertConfig +except: + ACT2FN = None + BaseModelOutputWithPastAndCrossAttentions = None + BaseModelOutputWithPoolingAndCrossAttentions = None + CausalLMOutputWithCrossAttentions = None + PreTrainedModel = None + apply_chunking_to_forward = None + find_pruneable_heads_and_indices = None + prune_linear_layer = None + BertConfig = None + +from mmpretrain.registry import MODELS + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + if config.add_type_embeddings: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if token_type_ids is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + else: + embeddings = inputs_embeds + + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
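        # For illustration, with hidden_states of shape (batch, seq_len, hidden):
        #   hidden_states[:, 0]       -> (batch, hidden), the [CLS] position
        #   tanh(dense(first_token))  -> (batch, hidden), the pooled vector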
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple + interface for downloading and loading pretrained models.""" + + config_class = BertConfig + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
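        # Schematically (shapes only; "text" queries attend to "image"
        # keys/values, as in the BLIP fusion layers):
        #   self-attention : Q, K, V all from hidden_states        (B, L_txt, H)
        #   cross-attention: Q from hidden_states                  (B, L_txt, H)
        #                    K, V from encoder_hidden_states       (B, L_img, H)
        # and attention_mask is swapped for encoder_attention_mask below, so
        # padded encoder positions are ignored.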
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = ( + attention_scores + relative_position_scores_query + + relative_position_scores_key) + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
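        # Concretely, nn.Dropout zeroes random entries of the
        # (batch, heads, L_query, L_key) probability map and rescales the rest
        # by 1 / (1 - p) during training, so some key positions are simply not
        # attended to on this forward pass.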
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ((context_layer, attention_probs) if output_attentions else + (context_layer, )) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config, twin=False, merge=False): + super().__init__() + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if twin: + self.dense0 = nn.Linear(config.hidden_size, config.hidden_size) + self.dense1 = nn.Linear(config.hidden_size, config.hidden_size) + else: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if merge: + self.act = ACT2FN[config.hidden_act] + self.merge_layer = nn.Linear(config.hidden_size * 2, + config.hidden_size) + self.merge = True + else: + self.merge = False + + def forward(self, hidden_states, input_tensor): + if type(hidden_states) == list: + hidden_states0 = self.dense0(hidden_states[0]) + hidden_states1 = self.dense1(hidden_states[1]) + if self.merge: + hidden_states = self.merge_layer( + torch.cat([hidden_states0, hidden_states1], dim=-1)) + else: + hidden_states = (hidden_states0 + hidden_states1) / 2 + else: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False, layer_num=-1): + super().__init__() + is_nlvr = is_cross_attention and getattr(config, 'nlvr', False) + if is_nlvr: + self.self0 = BertSelfAttention(config, is_nlvr) + self.self1 = BertSelfAttention(config, is_nlvr) + else: + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput( + config, + twin=is_nlvr, + merge=(is_nlvr and layer_num >= 6), + ) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + if type(encoder_hidden_states) == list: + self_outputs0 = self.self0( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[0], + encoder_attention_mask[0], + past_key_value, + output_attentions, + ) + self_outputs1 = self.self1( + 
hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[1], + encoder_attention_mask[1], + past_key_value, + output_attentions, + ) + attention_output = self.output( + [self_outputs0[0], self_outputs1[0]], hidden_states) + + outputs = (attention_output, ) + self_outputs0[ + 1:] # add attentions if we output them + else: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + + # compatibility for ALBEF and BLIP + try: + # ALBEF & ALPRO + fusion_layer = self.config.fusion_layer + add_cross_attention = ( + fusion_layer <= layer_num and self.config.add_cross_attention) + + self.fusion_layer = fusion_layer + except AttributeError: + # BLIP + self.fusion_layer = self.config.num_hidden_layers + add_cross_attention = self.config.add_cross_attention + + # if self.config.add_cross_attention: + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, + is_cross_attention=self.config.add_cross_attention, + layer_num=layer_num, + ) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + # TODO line 482 in albef/models/xbert.py + # compatibility for ALBEF and BLIP + if mode in ['multimodal', 'fusion'] and hasattr( + self, 'crossattention'): + assert ( + encoder_hidden_states is not None + ), 
'encoder_hidden_states must be given for cross-attention layers' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = (outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = (() if output_attentions + and self.config.add_cross_attention else None) + + next_decoder_cache = () if use_cache else None + + try: + # ALBEF + fusion_layer = self.config.fusion_layer + except AttributeError: + # BLIP + fusion_layer = self.config.num_hidden_layers + + if mode == 'text': + start_layer = 0 + # output_layer = self.config.fusion_layer + output_layer = fusion_layer + + elif mode == 'fusion': + # start_layer = self.config.fusion_layer + start_layer = fusion_layer + output_layer = self.config.num_hidden_layers + + elif mode == 'multimodal': + start_layer = 0 + output_layer = self.config.num_hidden_layers + + # compatibility for ALBEF and BLIP + # for i in range(self.config.num_hidden_layers): + for i in range(start_layer, output_layer): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + # TODO pay attention to this. + if self.gradient_checkpointing and self.training: + + if use_cache: + # TODO: logger here + # logger.warn( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ # ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@MODELS.register_module() +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need `__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. 
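    To use the model as a decoder (as ``BertLMHeadModel`` below does), pass
    ``is_decoder=True`` so that a causal mask is combined with the padding
    mask. A schematic call, with placeholder tensor names::

        out = model(input_ids,
                    attention_mask=mask,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    is_decoder=True,
                    return_dict=True)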
+ """ + + def __init__(self, config, add_pooling_layer=True): + if not isinstance(config, BertConfig): + config = BertConfig.from_dict(config) + + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + ) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= + seq_ids[None, :, None]) + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:, None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
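        mode (:obj:`str`, `optional`):
            Selects which encoder layers run and whether cross-attention is
            used; see ``BertEncoder.forward``. A usage sketch mirroring the
            calls made elsewhere in this patch (argument values are
            illustrative)::

                # text-only encoding, no cross-attention
                out = model(input_ids, attention_mask=mask,
                            return_dict=True, mode='text')

                # multimodal fusion with image features
                out = model(input_ids, attention_mask=mask,
                            encoder_hidden_states=image_embeds,
                            encoder_attention_mask=image_atts,
                            return_dict=True, mode='multimodal')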
+ """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or encoder_embeds' + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] + if past_key_values is not None else 0) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + 
encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BaseEncoder(nn.Module): + """Base class for primitive encoders, such as ViT, TimeSformer, etc.""" + + def __init__(self): + super().__init__() + + def forward_features(self, samples, **kwargs): + raise NotImplementedError + + @property + def device(self): + return list(self.parameters())[0].device + + +@MODELS.register_module() +class XBertEncoder(BertModel, BaseEncoder): + + def __init__(self, med_config, from_pretrained=False): + + med_config = BertConfig.from_dict(med_config) + super().__init__(config=med_config, add_pooling_layer=False) + + def forward_automask(self, tokenized_text, visual_embeds, **kwargs): + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + text = tokenized_text + text_output = super().forward( + text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return text_output + + def forward_text(self, tokenized_text, **kwargs): + text = tokenized_text + token_type_ids = kwargs.get('token_type_ids', None) + + text_output = super().forward( + text.input_ids, + attention_mask=text.attention_mask, + token_type_ids=token_type_ids, + return_dict=True, + mode='text', + ) + + return text_output + + +@MODELS.register_module() +class Linear(torch.nn.Linear): + """Wrapper for linear function.""" + + +@MODELS.register_module() +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. 
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, + BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained( + 'bert-base-cased') + >>> config = BertConfig.from_pretrained( + "bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained( + 'bert-base-cased', config=config) + >>> inputs = tokenizer( + "Hello, my dog is cute", + return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return 
CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, + # the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + +@MODELS.register_module() +class XBertLMHeadDecoder(BertLMHeadModel): + """This class decouples the decoder forward logic from the VL model. + + In this way, different VL models can share this decoder as long as they + feed encoder_embeds as required. + """ + + def __init__(self, med_config): + self.med_config = BertConfig.from_dict(med_config) + super(XBertLMHeadDecoder, self).__init__(config=self.med_config) + + def generate_from_encoder(self, + tokenized_prompt, + visual_embeds, + sep_token_id, + pad_token_id, + use_nucleus_sampling=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + **kwargs): + + if not use_nucleus_sampling: + num_beams = num_beams + visual_embeds = visual_embeds.repeat_interleave(num_beams, dim=0) + + image_atts = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long).to(self.device) + + model_kwargs = { + 'encoder_hidden_states': visual_embeds, + 'encoder_attention_mask': image_atts, + } + + if use_nucleus_sampling: + # nucleus sampling + outputs = self.generate( + input_ids=tokenized_prompt.input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + # beam search + outputs = self.generate( + input_ids=tokenized_prompt.input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + return outputs diff --git a/mmpretrain/models/multimodal/blip2/Qformer.py b/mmpretrain/models/multimodal/blip2/Qformer.py new file mode 100644 index 0000000..4b1c7d1 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/Qformer.py @@ -0,0 +1,773 @@ +# flake8: noqa +""" + * Copyright (c) 2023, salesforce.com, inc. 
+""" +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import apply_chunking_to_forward +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +from mmpretrain.registry import MODELS +from ..blip.language_model import (BertAttention, BertIntermediate, + BertOnlyMLMHead, BertOutput, BertPooler, + BertPreTrainedModel) + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if (self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached 
key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), 'encoder_hidden_states must be given for cross-attention layers' + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], + dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = (() if output_attentions + and self.config.add_cross_attention else None) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + + if use_cache: + logger.warn( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' 
+ ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions, query_length) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need `__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. 
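+
+        A minimal, self-contained sketch of the causal-mask rule applied when
+        ``is_decoder`` is ``True`` (illustrative only; it repeats the same
+        comparison used in the body of this method outside the model)::
+
+            >>> import torch
+            >>> seq_ids = torch.arange(4)
+            >>> causal_mask = (seq_ids[None, None, :].repeat(1, 4, 1) <=
+            ...                seq_ids[None, :, None])
+            >>> causal_mask[0]
+            tensor([[ True, False, False, False],
+                    [ True,  True, False, False],
+                    [ True,  True,  True, False],
+                    [ True,  True,  True,  True]])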
+ """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= + seq_ids[None, :, None]) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], + prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + if input_ids is None: + assert ( + query_embeds is not None + ), 'You have to specify query_embeds when input_ids is None' + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - + self.config.query_length if past_key_values is not None else 0) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
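+        # NOTE: in the decoder branch below, query tokens (when present) are
+        # handled with a UniLM-style prefix mask via ``has_query=True``; see
+        # ``get_extended_attention_mask`` above for the mask construction.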
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + if self.cls is not None: + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + 
return_dict=None,
+        return_logits=False,
+        is_decoder=True,
+        reduction='mean',
+    ):
+        r"""
+        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4
+            tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1]:, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + query_embeds, + past=None, + attention_mask=None, + **model_kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'query_embeds': + query_embeds, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + +@MODELS.register_module() +class Qformer(BertLMHeadModel): + + def __init__(self, model_style: str, vision_model_width: int, + add_cross_attention: bool, cross_attention_freq: int, + 
num_query_token: int) -> None: + + config = BertConfig.from_pretrained(model_style) + config.add_cross_attention = add_cross_attention + config.encoder_width = vision_model_width + config.cross_attention_freq = cross_attention_freq + config.query_length = num_query_token + super().__init__(config) diff --git a/mmpretrain/models/multimodal/blip2/__init__.py b/mmpretrain/models/multimodal/blip2/__init__.py new file mode 100644 index 0000000..b5695f2 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .blip2_caption import Blip2Caption +from .blip2_opt_vqa import Blip2VQA +from .blip2_retriever import Blip2Retrieval +from .modeling_opt import OPTForCausalLM +from .Qformer import Qformer + +__all__ = [ + 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'OPTForCausalLM', 'Qformer' +] diff --git a/mmpretrain/models/multimodal/blip2/blip2_caption.py b/mmpretrain/models/multimodal/blip2/blip2_caption.py new file mode 100644 index 0000000..acf6948 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/blip2_caption.py @@ -0,0 +1,315 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import torch +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class Blip2Caption(BaseModel): + """BLIP2 Caption. + + Module for BLIP2 Caption task. + + Args: + vision_backbone (dict): The config dict for vision backbone. + text_backbone (dict): The config dict for text backbone. + multimodal_backbone (dict): The config dict for multimodal backbone. + vision_neck (dict): The config dict for vision neck. + tokenizer: (Optional[dict]): The config for tokenizer. + Defaults to None. + prompt (str): Prompt used for training and eval. + Defaults to ''. + max_txt_len (int): Max text length of input text. + num_captions (int): Number of captions to be generated for each image. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. 
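+
+    The class is usually driven through the high-level caption inferencer
+    rather than built by hand. The snippet below is a rough sketch only: the
+    model name is a placeholder, not a guaranteed metafile entry, and the
+    image path is illustrative::
+
+        >>> from mmpretrain import ImageCaptionInferencer  # doctest: +SKIP
+        >>> inferencer = ImageCaptionInferencer('<blip2-caption-model>')  # doctest: +SKIP
+        >>> inferencer('demo/cat-dog.png')  # doctest: +SKIP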
+ """ + _no_split_modules = ['BEiTViT', 'OPTDecoderLayer', 'BertLayer'] + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + multimodal_backbone: dict, + vision_neck: dict, + tokenizer: Optional[dict] = None, + prompt: str = '', + max_txt_len: int = 20, + num_captions: int = 1, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.eos_token_id = self.tokenizer( + '\n', add_special_tokens=False).input_ids[0] + + self.vision_backbone = MODELS.build(vision_backbone) + self.ln_vision_backbone = nn.LayerNorm(self.vision_backbone.embed_dims) + + self.vision_neck = MODELS.build(vision_neck) + + self.text_backbone = MODELS.build(text_backbone) + + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.multimodal_backbone.cls = None + self.multimodal_backbone.bert.embeddings.word_embeddings = None + self.multimodal_backbone.bert.embeddings.position_embeddings = None + for layer in self.multimodal_backbone.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + self.prompt = prompt + self.max_txt_len = max_txt_len + self.num_captions = num_captions + prompt_tokens = self.tokenizer(prompt, return_tensors='pt') + self.prompt_length = prompt_tokens.attention_mask.sum(1) + + self.query_tokens = nn.Parameter( + torch.zeros(1, self.multimodal_backbone.bert.config.query_length, + self.multimodal_backbone.bert.config.hidden_size)) + self.query_tokens.data.normal_( + mean=0.0, + std=self.multimodal_backbone.bert.config.initializer_range) + + # freeze the text backbone + for _, param in self.text_backbone.named_parameters(): + param.requires_grad = False + + if hasattr(self, 'register_load_state_dict_post_hook'): + self.register_load_state_dict_post_hook( + self._ignore_loading_llm_keys_hook) + + if hasattr(self, '_register_state_dict_hook'): + self._register_state_dict_hook(self._igonre_saving_llm_keys_hook) + + def forward(self, + images: torch.Tensor, + data_samples: Optional[List] = None, + mode: str = 'loss'): + """The unified entry for a forward process in both training and test. + The method should accept two modes: "predict" and "loss": + + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (torch.Tensor): pre_processed img tensor (N, C, ...). + data_samples (List[DataSample], optional): + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. + - If ``mode="loss"``, return a dict of tensor. + - If ``mode="predict"``, return a list of + :obj:`mmpretrain.structures.DataSample`. 
+ """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def loss(self, + images: torch.Tensor, + data_samples: Optional[list] = None, + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``loss`` + method of :attr:`head`. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + # extract image features + image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], + dtype=torch.long, + ).to(images.device) + + # distill image features to query tokens + query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1) + query_outputs = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + inputs_opt = self.vision_neck([query_outputs.last_hidden_state]) + attns_opt = torch.ones( + inputs_opt.size()[:-1], dtype=torch.long).to(images.device) + + self.tokenizer.padding_side = 'right' + + prompt = [ + self.prompt + data_sample.gt_caption + '\n' + for data_sample in data_samples + ] + + opt_tokens = self.tokenizer( + prompt, + return_tensors='pt', + padding='longest', + truncation=True, + max_length=self.max_txt_len, + ).to(images.device) + + targets = opt_tokens.input_ids.masked_fill( + opt_tokens.input_ids == self.tokenizer.pad_token_id, -100) + if self.prompt: + targets[:, :self.prompt_length] = -100 + + empty_targets = ( + torch.ones(attns_opt.size(), + dtype=torch.long).to(images.device).fill_(-100)) + targets = torch.cat([empty_targets, targets], dim=1) + + inputs_embeds = ( + self.text_backbone.model.decoder.embed_tokens( + opt_tokens.input_ids)) + inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) + attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask], + dim=1) + + outputs = self.text_backbone( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + return {'loss': loss} + + def predict(self, + images: torch.Tensor, + data_samples: Optional[list] = None, + **kwargs) -> List[DataSample]: + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: Return list of data samples. 
+ """ + + # extract image features + image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], + dtype=torch.long, + ).to(images.device) + + # distill image features to query tokens + query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1) + query_outputs = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + inputs_opt = self.vision_neck([query_outputs.last_hidden_state]) + attns_opt = torch.ones( + inputs_opt.size()[:-1], dtype=torch.long).to(images.device) + + prompt = [self.prompt] * image_embeds.size(0) + + opt_tokens = self.tokenizer( + prompt, + return_tensors='pt', + padding='longest', + truncation=True, + max_length=self.max_txt_len, + ).to(images.device) + attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask], + dim=1) + + inputs_embeds = ( + self.text_backbone.get_input_embeddings()(opt_tokens.input_ids)) + inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) + + outputs = self.text_backbone.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=False, + top_p=0.9, + temperature=1., + num_beams=5, + max_new_tokens=self.max_txt_len, + min_length=1, + eos_token_id=self.eos_token_id, + repetition_penalty=1.0, + length_penalty=1.0, + num_return_sequences=self.num_captions, + ) + + output_text = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(len(output_text))] + + for data_sample, decode_token in zip(data_samples, output_text): + if data_sample is None: + data_sample = DataSample() + data_sample.pred_caption = decode_token + out_data_samples.append(data_sample) + + return out_data_samples + + @staticmethod + def _ignore_loading_llm_keys_hook(module, incompatible_keys): + """Avoid warning missing keys of the LLM model.""" + import re + llm_pattern = '^text_backbone' + for key in list(incompatible_keys.missing_keys): + if re.match(llm_pattern, key): + incompatible_keys.missing_keys.remove(key) + + @staticmethod + def _igonre_saving_llm_keys_hook(module, state_dict, prefix, metadata): + """Avoid saving llm state dict.""" + import re + llm_pattern = '^text_backbone' + keys = [k for k, _ in state_dict.items()] + for key in keys: + if re.match(llm_pattern, key): + state_dict.pop(key) diff --git a/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py b/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py new file mode 100644 index 0000000..20e439f --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/blip2_opt_vqa.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .blip2_caption import Blip2Caption + + +@MODELS.register_module() +class Blip2VQA(Blip2Caption): + """BLIP2 VQA. + + Module for BLIP2 VQA task. For more details about the initialization + params, please refer to :class:`Blip2Caption`. + """ + + def predict(self, + images: torch.Tensor, + data_samples: Optional[list] = None, + **kwargs) -> List[DataSample]: + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. 
Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + + Returns: + List[DataSample]: Return list of data samples. + """ + questions = [d.question for d in data_samples] + + # extract image features from + image_embeds = self.ln_vision_backbone(self.vision_backbone(images)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], + dtype=torch.long, + ).to(images.device) + + # distill image features to query tokens + query_tokens = self.query_tokens.expand(image_embeds.size(0), -1, -1) + query_outputs = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + inputs_opt = self.vision_neck([query_outputs.last_hidden_state]) + attns_opt = torch.ones( + inputs_opt.size()[:-1], dtype=torch.long).to(images.device) + + prompt = [self.prompt.format(q) for q in questions] + + # use left padding + self.tokenizer.padding_side = 'left' + + opt_tokens = self.tokenizer( + prompt, return_tensors='pt', padding='longest').to(images.device) + input_ids = opt_tokens.input_ids + attention_mask = torch.cat([attns_opt, opt_tokens.attention_mask], + dim=1) + + inputs_embeds = self.text_backbone.model.decoder.embed_tokens( + input_ids) + inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) + + outputs = self.text_backbone.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=False, + num_beams=5, + max_new_tokens=self.max_txt_len, + min_length=1, + eos_token_id=self.eos_token_id, + length_penalty=-1.0, + ) + + output_text = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + + out_data_samples = [] + for data_sample, decode_token in zip(data_samples, output_text): + data_sample.pred_answer = decode_token + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/blip2/blip2_retriever.py b/mmpretrain/models/multimodal/blip2/blip2_retriever.py new file mode 100644 index 0000000..e626404 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/blip2_retriever.py @@ -0,0 +1,505 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import mmengine.dist as dist +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.utils import track_iter_progress + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from ..blip.blip_retrieval import BlipRetrieval, all_gather_concat + + +@MODELS.register_module() +class Blip2Retrieval(BlipRetrieval): + """BLIP2 Retriever. + + Args: + vision_backbone (dict): Backbone for extracting image features. + text_backbone (dict): Backbone for extracting text features. + multimodal_backbone (Optional[dict]): Backbone for extracting + multi-modal features. + vision_neck (Optional[dict]): The neck module to process image features + from vision backbone. Defaults to None. + text_neck (Optional[dict]): The neck module to process text features + from text backbone. Defaults to None. + head (Optional[Union[List[dict], dict]]): The head module to calculate + loss from processed single modality features. + See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. 
+ multimodal_head (Optional[Union[List[dict], dict]]): The multi-modal + head module to calculate loss from processed multimodal features. + See :mod:`mmmultimodal.models.heads`. + Notice that if the head is not set, `loss` method cannot be used. + Defaults to None. + tokenizer (Optional[dict]): The config for tokenizer. Defaults to None. + temperature (float): Temperature parameter that controls the + concentration level of the distribution. Defaults to 0.07. + fast_match (bool): If False, select topk similarity as candidates and + compute the matching score. If True, return the similarity as the + matching score directly. Defaults to False. + topk (int): Select topk similarity as candidates for compute matching + scores. Notice that this is not the topk in evaluation. + Defaults to 256. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + text_backbone: Optional[dict] = None, + multimodal_backbone: Optional[dict] = None, + vision_neck: Optional[dict] = None, + text_neck: Optional[dict] = None, + head: Optional[Union[List[dict], dict]] = None, + multimodal_head: Optional[Union[List[dict], dict]] = None, + tokenizer: Optional[dict] = None, + temperature: float = 0.07, + fast_match: bool = False, + topk: int = 256, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + # Skip BlipRetrieval init + super(BlipRetrieval, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.vision_backbone = MODELS.build(vision_backbone) + self.ln_vision_backbone = nn.LayerNorm(self.vision_backbone.embed_dims) + self.tokenizer = TOKENIZER.build(tokenizer) + + if text_backbone is not None: + self.text_backbone = MODELS.build(text_backbone) + + if multimodal_backbone is not None: + self.multimodal_backbone = MODELS.build(multimodal_backbone) + self.multimodal_backbone.resize_token_embeddings( + len(self.tokenizer)) + self.query_tokens = nn.Parameter( + torch.zeros(1, self.multimodal_backbone.bert.config.query_length, + self.multimodal_backbone.bert.config.hidden_size)) + self.query_tokens.data.normal_( + mean=0.0, + std=self.multimodal_backbone.bert.config.initializer_range) + + if vision_neck is not None: + self.vision_neck = MODELS.build(vision_neck) + + if text_neck is not None: + self.text_neck = MODELS.build(text_neck) + + if head is not None: + self.head = MODELS.build(head) + + if multimodal_head is not None: + self.multimodal_head = MODELS.build(multimodal_head) + + self.temp = nn.Parameter(temperature * torch.ones([])) + + # Notice that this topk is used for select k candidate to compute + # image-text score, but not the final metric topk in evaluation. + self.fast_match = fast_match + self.topk = topk + + def _extract_feat(self, inputs: Union[torch.Tensor, dict], + modality: str) -> Tuple[torch.Tensor]: + """Extract features from the single modality. + Args: + inputs (Union[torch.Tensor, dict]): A batch of inputs. + For image, a tensor of shape (N, C, ...) in general. 
+ For text, a dict of tokenized text inputs. + modality (str): Modality feature to be extracted. Only two + options are supported. + + - ``images``: Only extract image features, mostly used for + inference. + - ``texts``: Only extract text features, mostly used for + inference. + Returns: + Tuple[torch.Tensor]: The output features. + """ + if modality == 'images': + # extract image features + # TODO: + # Add layernorm inside backbone and handle the concat outside + image_embeds = self.ln_vision_backbone( + self.vision_backbone(inputs)[0]) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, + -1) + query_output = self.multimodal_backbone.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + use_cache=True, + return_dict=True, + ) + image_feat = F.normalize( + self.vision_neck([query_output.last_hidden_state]), dim=-1) + return { + 'image_embeds': image_embeds, + 'image_feat': image_feat, + 'query_output': query_output + } + elif modality == 'texts': + # extract text features + text_output = self.multimodal_backbone.bert( + inputs.input_ids, + attention_mask=inputs.attention_mask, + return_dict=True, + ) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize( + self.text_neck([text_embeds[:, 0, :]]), dim=-1) + return {'text_embeds': text_embeds, 'text_feat': text_feat} + else: + raise RuntimeError(f'Invalid modality "{modality}".') + + def loss( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + ) -> Dict[str, torch.tensor]: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (dict): A batch of inputs. The input tensor with of + at least one modality. For image, the value is a tensor + of shape (N, C, ...) in general. + For text, the value is a dict of tokenized text inputs. + data_samples (Optional[List[DataSample]]): + The annotation data of every samples. Defaults to None. + + Returns: + Dict[str, torch.tensor]: a dictionary of loss components of + both head and multimodal head. 
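+
+        The image-text contrastive (ITC) targets are simply the indices of the
+        in-batch positives inside the gathered all-rank features; the snippet
+        below just restates the ``torch.linspace`` call used in the body::
+
+            >>> import torch
+            >>> rank, bs = 1, 4
+            >>> torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int)
+            tensor([4, 5, 6, 7])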
+ """ + output = self.extract_feat(images, data_samples) + + text_ids = output['text_ids'] + text_attn_mask = output['text_attn_mask'] + image_embeds = output['image_embeds'] + image_feat = output['image_feat'] + text_feat = output['text_feat'] + query_output = output['query_output'] + + # ITC Loss + # B*world_size, num_query, D + image_feat_all = torch.cat(dist.all_gather(image_feat)) + # B*world_size, D + text_feat_all = torch.cat(dist.all_gather(text_feat)) + + # B, B*world_size, num_query + sim_q2t = torch.matmul( + image_feat.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze() + + # image to text similarity + sim_i2t, _ = sim_q2t.max(-1) + sim_i2t = sim_i2t / self.temp + + # B, B*world_size, num_query + sim_t2q = torch.matmul( + text_feat.unsqueeze(1).unsqueeze(1), + image_feat_all.permute(0, 2, 1)).squeeze() + + # text-image similarity + sim_t2i, _ = sim_t2q.max(-1) + sim_t2i = sim_t2i / self.temp + + rank = dist.get_rank() + bs = images.size(0) + targets = torch.linspace( + rank * bs, rank * bs + bs - 1, bs, dtype=int).to(self.device) + + itc_loss = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2 + + # prepare for itm + text_input_ids_world = torch.cat(dist.all_gather(text_ids)) + text_attention_mask_world = torch.cat(dist.all_gather(text_attn_mask)) + image_embeds_world = torch.cat(dist.all_gather(image_embeds)) + with torch.no_grad(): + weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4 + weights_t2i[:, rank * bs:rank * bs + bs].fill_diagonal_(0) + weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4 + weights_i2t[:, rank * bs:rank * bs + bs].fill_diagonal_(0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(text_input_ids_world[neg_idx]) + text_atts_neg.append(text_attention_mask_world[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat([text_ids, text_ids, text_ids_neg], + dim=0) # pos, pos, neg + text_atts_all = torch.cat( + [text_attn_mask, text_attn_mask, text_atts_neg], + dim=0, + ) + + query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, + -1) + query_atts_itm = torch.ones( + query_tokens_itm.size()[:-1], dtype=torch.long).to(self.device) + attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) + + image_embeds_all = torch.cat( + [image_embeds, image_embeds_neg, image_embeds], + dim=0) # pos, neg, pos + image_atts_all = torch.ones( + image_embeds_all.size()[:-1], dtype=torch.long).to(self.device) + + output_itm = self.multimodal_backbone.bert( + text_ids_all, + query_embeds=query_tokens_itm, + attention_mask=attention_mask_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = output_itm.last_hidden_state[:, :query_tokens_itm. 
+ size(1), :] + + # create false data samples + data_samples.extend( + [DataSample(is_matched=False) for _ in range(2 * bs)]) + loss_multimodal = self.multimodal_head.loss((vl_embeddings, ), + data_samples) + + # LM loss + decoder_input_ids = text_ids.clone() + decoder_input_ids[:, 0] = self.tokenizer.bos_token_id + labels = decoder_input_ids.masked_fill( + decoder_input_ids == self.tokenizer.pad_token_id, -100) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_atts = torch.ones( + query_tokens.size()[:-1], dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text_attn_mask], dim=1) + lm_output = self.multimodal_backbone( + decoder_input_ids, + attention_mask=attention_mask, + past_key_values=query_output.past_key_values, + return_dict=True, + labels=labels, + ) + + return dict( + itc_loss=itc_loss, **loss_multimodal, lm_loss=lm_output.loss) + + def predict_all(self, + feats: Dict[str, torch.Tensor], + data_samples: List[DataSample], + num_images: int = None, + num_texts: int = None, + cal_i2t: bool = True, + cal_t2i: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute similarity matrix between images and texts across all ranks. + + Args: + feats (Dict[str, torch.Tensor]): Features from the current rank. + data_samples (List[DataSample]): Data samples from the current + rank. + num_images (int, optional): Number of images to use. + Defaults to None. + num_texts (int, optional): Number of texts to use. + Defaults to None. + cal_i2t (bool, optional): Whether to compute image-to-text + similarity. Defaults to True. + cal_t2i (bool, optional): Whether to compute text-to-image + similarity. Defaults to True. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Image-to-text and text-to-image + similarity matrices. + """ + text_ids = feats['text_ids'] + text_attn_mask = feats['text_attn_mask'] + image_embeds = feats.get('image_embeds', None) + image_feat = feats['image_feat'] + text_feat = feats['text_feat'] + + num_images = num_images or image_feat.size(0) + num_texts = num_texts or text_feat.size(0) + + if not self.fast_match: + image_embeds_all = all_gather_concat(image_embeds)[:num_images] + else: + image_embeds_all = None + image_feat_all = all_gather_concat(image_feat)[:num_images] + text_feat_all = all_gather_concat(text_feat)[:num_texts] + text_ids_all = all_gather_concat(text_ids)[:num_texts] + text_attn_mask_all = all_gather_concat(text_attn_mask)[:num_texts] + + results = [] + if cal_i2t: + result_i2t = self.compute_score_matrix_i2t( + image_feat, + image_embeds, + text_feat_all, + text_ids_all, + text_attn_mask_all, + ) + results.append( + self._get_predictions(result_i2t, data_samples, mode='i2t')) + if cal_t2i: + result_t2i = self.compute_score_matrix_t2i( + image_feat_all, + image_embeds_all, + text_feat, + text_ids, + text_attn_mask, + ) + results.append( + self._get_predictions(result_t2i, data_samples, mode='t2i')) + return tuple(results) + + def compute_score_matrix_i2t(self, img_feats: torch.Tensor, + img_embeds: List[torch.Tensor], + text_feats: torch.Tensor, + text_ids: torch.Tensor, + text_atts: torch.Tensor) -> torch.Tensor: + """Compare the score matrix for image-to-text retrieval. Every image + should compare to all the text features. + + Args: + img_feats (torch.Tensor): The input tensor with shape (M, C). + M stands for numbers of samples on a single GPU. + img_embeds (List[torch.Tensor]): Image features from each layer of + the vision backbone. 
+ text_feats (torch.Tensor): The input tensor with shape (N, C).
+ N stands for the number of samples across all GPUs.
+ text_ids (torch.Tensor): The input tensor with shape (N, C).
+ text_atts (torch.Tensor): The input tensor with shape (N, C).
+
+ Returns:
+ torch.Tensor: Score matrix of image-to-text retrieval.
+ """
+
+ # compute i2t sim matrix
+ # TODO: check correctness
+ sim_matrix_i2t, _ = (img_feats @ text_feats.t()).max(1)
+ if self.fast_match:
+ return sim_matrix_i2t
+
+ score_matrix_i2t = torch.full((img_feats.size(0), text_feats.size(0)),
+ -100.0).to(self.device)
+
+ for i in track_iter_progress(range(img_feats.size(0))):
+ sims = sim_matrix_i2t[i]
+ topk_sim, topk_idx = sims.topk(k=self.topk, dim=0)
+ # get repeated image embeddings
+ encoder_output = img_embeds[i].repeat(self.topk, 1, 1)
+ encoder_att = torch.ones(
+ encoder_output.size()[:-1], dtype=torch.long).to(self.device)
+ # query embeds and attention masks
+ query_tokens = self.query_tokens.expand(encoder_output.shape[0],
+ -1, -1)
+ query_atts = torch.ones(
+ query_tokens.size()[:-1], dtype=torch.long).to(self.device)
+ attention_mask = torch.cat([query_atts, text_atts[topk_idx]],
+ dim=1)
+ output = self.multimodal_backbone.bert(
+ text_ids[topk_idx],
+ query_embeds=query_tokens,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ )
+ score = self.multimodal_head(
+ (output.last_hidden_state[:, :query_tokens.size(1), :],
+ ))[:, :, 1].mean(dim=1)
+ score_matrix_i2t[i, topk_idx] = score + topk_sim
+
+ return score_matrix_i2t
+
+ def compute_score_matrix_t2i(self, img_feats: torch.Tensor,
+ img_embeds: List[torch.Tensor],
+ text_feats: torch.Tensor,
+ text_ids: torch.Tensor,
+ text_atts: torch.Tensor) -> torch.Tensor:
+ """Compute the score matrix for text-to-image retrieval.
+
+ Every text is compared against all image features.
+
+ Args:
+ img_feats (torch.Tensor): The input tensor with shape (N, C).
+ N stands for the number of samples across all GPUs.
+ img_embeds (List[torch.Tensor]): Image features from each layer of
+ the vision backbone.
+ text_feats (torch.Tensor): The input tensor with shape (M, C).
+ M stands for the number of samples on a single GPU.
+ text_ids (torch.Tensor): The input tensor with shape (M, C).
+ text_atts (torch.Tensor): The input tensor with shape (M, C).
+
+ Returns:
+ torch.Tensor: Score matrix of text-to-image retrieval.
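+
+ Note:
+ When ``self.fast_match`` is True, only the coarse query-text
+ similarity matrix is returned and the ITM re-ranking below is
+ skipped.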
+ """ + + # compute t2i sim matrix + # TODO: check correctness + sim_matrix_i2t, _ = (img_feats @ text_feats.t()).max(1) + sim_matrix_t2i = sim_matrix_i2t.t() + if self.fast_match: + return sim_matrix_i2t + + score_matrix_t2i = torch.full((text_feats.size(0), img_feats.size(0)), + -100.0).to(self.device) + + for i in track_iter_progress(range(text_feats.size(0))): + sims = sim_matrix_t2i[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + # get topk image embeddings + encoder_output = img_embeds[topk_idx] + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + # get query embeds and attention masks + query_tokens = self.query_tokens.expand(encoder_output.shape[0], + -1, -1) + query_atts = torch.ones( + query_tokens.size()[:-1], dtype=torch.long).to(self.device) + attention_mask = torch.cat( + [query_atts, text_atts[i].repeat(self.topk, 1)], dim=1) + output = self.multimodal_backbone.bert( + text_ids[i].repeat(self.topk, 1), + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = self.multimodal_head( + (output.last_hidden_state[:, :query_tokens.size(1), :], + ))[:, :, 1].mean(dim=1) + score_matrix_t2i[i, topk_idx] = score + topk_sim + + return score_matrix_t2i diff --git a/mmpretrain/models/multimodal/blip2/modeling_opt.py b/mmpretrain/models/multimodal/blip2/modeling_opt.py new file mode 100644 index 0000000..7cde0d7 --- /dev/null +++ b/mmpretrain/models/multimodal/blip2/modeling_opt.py @@ -0,0 +1,1083 @@ +# flake8: noqa +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch OPT model.""" +import random +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.opt.configuration_opt import OPTConfig +from transformers.utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) + +from mmpretrain.models.utils import register_hf_model + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'facebook/opt-350m' +_CONFIG_FOR_DOC = 'OPTConfig' +_TOKENIZER_FOR_DOC = 'GPT2Tokenizer' + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'facebook/opt-125m', + 'facebook/opt-350m', + 'facebook/opt-1.3b', + 'facebook/opt-2.7b', + 'facebook/opt-6.7b', + 'facebook/opt-13b', + 'facebook/opt-30b', + # See all OPT models at https://huggingface.co/models?filter=opt +] + + +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + past_key_values_length: int = 0): + """Make causal mask used for bi-directional self-attention.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, + src_seq_len]`.""" + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), + torch.finfo(dtype).min) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """This module learns positional embeddings up to a fixed maximum size.""" + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = ( + torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * + attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}' + f' and `num_heads`: {num_heads}).') + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return (tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous()) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel.""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f'Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is' + f' {attn_weights.size()}') + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}' + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask) + attn_weights = torch.max( + attn_weights, + torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads, ): + raise ValueError( + f'Head mask for a single layer should be of size {(self.num_heads,)}, but is' + f' {layer_head_mask.size()}') + attn_weights = layer_head_mask.view( + 1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, + tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, + self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is' + f' {attn_output.size()}') + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, + self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
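+ # (bsz, tgt_len, num_heads, head_dim) -> (bsz, tgt_len, embed_dim)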
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OPTDecoderLayer(nn.Module): + + def __init__(self, config: OPTConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, ) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + 'The bare OPT Model outputting raw hidden-states without any specific head on top.', + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + + config_class = OPTConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['OPTDecoderLayer'] + _keys_to_ignore_on_load_unexpected = [r'decoder\.version'] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (OPTDecoder)): + module.gradient_checkpointing = value + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """Transformer decoder consisting of *config.num_hidden_layers* layers. + Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.word_embed_proj_dim, + self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear( + config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear( + config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList( + [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ).to(inputs_embeds.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + 
combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else + expanded_attn_mask + combined_attention_mask) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
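+ query_embeds (`torch.FloatTensor`, *optional*):
+ Extra query embeddings (e.g. the output of a BLIP-2 Q-Former) that are prepended to the token
+ embeddings along the sequence dimension before decoding.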
+ """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time' + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + 'You have to specify either decoder_input_ids or decoder_inputs_embeds' + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if query_embeds is not None: + inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1) + input_shape = inputs_embeds.size()[:-1] + else: + input_shape = (batch_size, seq_length) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + inputs_embeds.shape[:2], + dtype=torch.bool, + device=inputs_embeds.device) + pos_embeds = self.embed_positions(attention_mask, + past_key_values_length) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ['head_mask']): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f'The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for' + f' {head_mask.size()[0]}.') + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' 
+ ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += ( + layer_outputs[2 if output_attentions else 1], ) + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v for v in + [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + 'The bare OPT Model outputting raw hidden-states without any specific head on top.', + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.decoder = OPTDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + 
inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +@register_hf_model() +class OPTForCausalLM(OPTPreTrainedModel): + _keys_to_ignore_on_load_missing = [r'lm_head.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear( + config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = 'mean', + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import GPT2Tokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
+ ```""" + + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + logits = logits[:, -labels.size(1):, :] + + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + loss = loss_fct( + shift_logits.view(-1, self.config.vocab_size), + shift_labels.view(-1)) + if reduction == 'none': + loss = loss.view(shift_logits.size(0), -1).sum(1) + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + query_embeds=None, + past_key_values=None, + attention_mask=None, + use_cache=None, + **kwargs, + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + if input_ids is not None: + attention_mask = input_ids.new_ones(input_ids.shape) + if past_key_values: + input_ids = input_ids[:, -1:] + query_embeds = None + # first step, decoder_cached_states are empty + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ + 'query_embeds': query_embeds, + 'attention_mask': attention_mask, + 'past_key_values': past_key_values, + 'use_cache': use_cache, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/chinese_clip/__init__.py b/mmpretrain/models/multimodal/chinese_clip/__init__.py new file mode 100644 index 0000000..460e9e6 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .bert import BertModelCN +from .chinese_clip import ChineseCLIP, ModifiedResNet + +__all__ = ['ChineseCLIP', 'ModifiedResNet', 'BertModelCN'] diff --git a/mmpretrain/models/multimodal/chinese_clip/bert.py b/mmpretrain/models/multimodal/chinese_clip/bert.py new file mode 100644 index 0000000..4e8dc73 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/bert.py @@ -0,0 +1,263 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + +# flake8: noqa +import math + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +try: + from transformers.models.bert.configuration_bert import BertConfig +except: + BertConfig = None + +from mmpretrain.registry import MODELS +from ..blip.language_model import BertAttention, BertIntermediate, BertOutput + + +def gelu(x): + """Original Implementation of the gelu activation function in Google Bert + repo when initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives + slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ # noqa + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def gelu_new(x): + """Implementation of the gelu activation function currently in Google Bert + repo (identical to OpenAI GPT) https://arxiv.org/abs/1606.08415.""" + return 0.5 * x * (1 + torch.tanh( + math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = { + 'gelu': gelu, + 'relu': torch.nn.functional.relu, + 'swish': swish, + 'gelu_new': gelu_new +} + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type + embeddings.""" + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model + # variable name and be able to load any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, position_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings \ + + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertLayer(nn.Module): + + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + attention_outputs = self.attention(hidden_states, attention_mask, + head_mask) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output, ) + attention_outputs[ + 1:] # add attentions if we output them + if len(outputs) == 1: + return outputs[0] + return outputs + + +class 
BertEncoder(nn.Module): + + def __init__(self, config): + super(BertEncoder, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.grad_checkpointing = False + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + layer_outputs = checkpoint(layer_module, hidden_states, + attention_mask, head_mask[i]) + else: + layer_outputs = layer_module(hidden_states, attention_mask, + head_mask[i]) + if not isinstance(layer_outputs, tuple): + layer_outputs = (layer_outputs, ) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + outputs = (hidden_states, ) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if self.output_attentions: + outputs = outputs + (all_attentions, ) + # last-layer hidden state, (all hidden states), (all attentions) + return outputs + + +class BertPreTrainedModel(nn.Module): + base_model_prefix = 'bert' + + def __init__(self, config): + super(BertPreTrainedModel, self).__init__() + self.config = config + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version + # which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +@MODELS.register_module() +class BertModelCN(BertPreTrainedModel): + """The BERT model implementation for Chinese CLIP.""" + + def __init__(self, config): + config = BertConfig.from_dict(config) + super(BertModelCN, self).__init__(config) + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.apply(self._init_weights) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + if enable: + assert not self.config.output_attentions, \ + 'Grad checkpointing is currently conflict with ' \ + 'output_attentions for BertEncoder, ' \ + 'please set it to False in BertConfig' + + self.encoder.grad_checkpointing = enable + + def forward(self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=next(self.parameters()).dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
+ -1).unsqueeze(-1)
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1,
+ -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
+ -1) # We can specify head_mask for each layer
+ head_mask = head_mask.to(dtype=next(self.parameters(
+ )).dtype) # switch to float if needed + fp16 compatibility
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ embedding_output = self.embeddings(
+ input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids)
+ encoder_outputs = self.encoder(
+ embedding_output, extended_attention_mask, head_mask=head_mask)
+ sequence_output = encoder_outputs[0]
+ # pooled_output = self.pooler(sequence_output)
+ pooled_output = None
+
+ # add hidden_states and attentions if they are here
+ outputs = (
+ sequence_output,
+ pooled_output,
+ ) + encoder_outputs[1:]
+
+ # sequence_output, pooled_output, (hidden_states), (attentions)
+ return outputs
diff --git a/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py b/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py
new file mode 100644
index 0000000..40af564
--- /dev/null
+++ b/mmpretrain/models/multimodal/chinese_clip/chinese_clip.py
@@ -0,0 +1,446 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model import BaseModel, BaseModule +from torch import nn + +from mmpretrain.datasets.categories import CIFAR100_CATEGORIES_CN +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .utils import OPENAI_PROMPT + +PROTOTYPE_MAP = {'cifar100': CIFAR100_CATEGORIES_CN} +PROMPT_MAP = {'openai': OPENAI_PROMPT} + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +@MODELS.register_module() +class ModifiedResNet(BaseModule): + """A modified ResNet contains the following changes: + + - Apply deep stem with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is + prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ # noqa + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth: int = 50, + base_channels: int = 64, + input_size: int = 224, + num_attn_heads: int = 32, + output_dim: int = 1024, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + self.input_size = input_size + self.block, stage_blocks = self.arch_settings[depth] + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, + base_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(base_channels // 2) + self.conv2 = nn.Conv2d( + base_channels // 2, + base_channels // 2, + kernel_size=3, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(base_channels // 2) + self.conv3 = nn.Conv2d( + base_channels // 2, + base_channels, + kernel_size=3, + padding=1, + bias=False) + self.bn3 = nn.BatchNorm2d(base_channels) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + # this is a *mutable* variable used during construction + self._inplanes = base_channels + self.layer1 = self._make_layer(base_channels, stage_blocks[0]) + self.layer2 = self._make_layer( + base_channels * 2, stage_blocks[1], stride=2) + self.layer3 = self._make_layer( + base_channels * 4, stage_blocks[2], stride=2) + self.layer4 = self._make_layer( + base_channels * 8, stage_blocks[3], stride=2) + + embed_dim = base_channels * 32 + self.attnpool = AttentionPool2d(input_size // 32, embed_dim, + num_attn_heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +@MODELS.register_module() +class ChineseCLIP(BaseModel): + """The implementation of `ChineseCLIP `_. + + Args: + vision_backbone (dict): Config dict for vision backbone. + text_backbone (dict): Config dict for text backbone. + tokenizer (dict): Config dict for text tokenizer. + proj_dim (int): Projection dimension for similarity computation. + text_prototype (str): Text prototype, which can be a key in + `PROTOTYPE_MAP` or list of text. + text_prompt (str): The prompt for text prototype. Defaults to 'openai'. + context_length (int): The context length to use. Defaults to 52. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The config to control the initialization. + Defaults to None. 
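A quick usage sketch of the backbone defined above, assuming the package from this patch is importable; the dummy input is only meant to show the expected shapes produced by the attention-pooled head:

```python
import torch
from mmpretrain.models.multimodal.chinese_clip.chinese_clip import ModifiedResNet

model = ModifiedResNet(depth=50, input_size=224, output_dim=1024)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(2, 3, 224, 224))  # stem + 4 stages + AttentionPool2d
print(feats.shape)  # torch.Size([2, 1024]): one pooled image embedding per sample
```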
+ """ + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + tokenizer: dict, + proj_dim: int, + text_prototype: Union[str, List[str]], + text_prompt: str = 'openai', + context_length: int = 52, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.vision_backbone = MODELS.build(vision_backbone) + self.text_backbone = MODELS.build(text_backbone) + + if not isinstance(self.vision_backbone, ModifiedResNet): + self.vision_projection = nn.Parameter( + torch.empty(self.vision_backbone.embed_dims, proj_dim)) + text_hidden_size = text_backbone['config']['hidden_size'] + self.text_projection = nn.Parameter( + torch.empty(text_hidden_size, proj_dim)) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.context_length = context_length + + # for zero-shot classification + if isinstance(text_prototype, + str) and text_prototype in PROTOTYPE_MAP.keys(): + self.prototype = PROTOTYPE_MAP[text_prototype] + else: + self.prototype = text_prototype + self.text_prototype_embeds = None + + self.prompt = PROMPT_MAP[text_prompt] + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: + """The function to extract image latent features.""" + if isinstance(self.vision_backbone, ModifiedResNet): + return self.vision_backbone(images) + return self.vision_backbone(images)[-1] @ self.vision_projection + + def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: + """The function to extract text latent features.""" + pad_index = self.tokenizer.vocab['[PAD]'] + attn_mask = texts.ne(pad_index) + # [batch_size, seq_length, hidden_size] + x = self.text_backbone(texts, attention_mask=attn_mask)[0] + return x[:, 0, :] @ self.text_projection + + def extract_feat( + self, images: torch.Tensor, + texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """The function to extract image and text latent features, the input + image or text can not both be None.""" + + assert images is not None or texts is not None, \ + 'text and image cannot both be None!' 
+ if images is None: + return self.extract_text_feat(texts) + elif texts is None: + return self.extract_image_feat(images) + + image_features = self.extract_image_feat(images) + text_features = self.extract_text_feat(texts) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + return image_features, text_features + + def compute_similarity(self, images, texts): + """Extract images and texts features and compute cosine similarity.""" + image_features, text_features = self.extract_feat( + images=images, texts=texts) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape (N, N) + return logits_per_image, logits_per_text + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + """Predict the classes of the input images. + + The prediction is for zero-shot classification and the text prototypes + will be prepared in thisfunction. + + Args: + images (torch.Tensor): The input images. + data_samples (DataSample): The data samples with information from + dataset. + + Returns: + DataSample: The results of prediction. + """ + + if self.text_prototype_embeds is None: + self.prepare_text_prototype(device=images.device) + + image_features = self.extract_image_feat(images=images) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = image_features @ self.text_prototype_embeds.to( + image_features.device) * self.logit_scale.exp() + + pred_scores = F.softmax(logits_per_image, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples + + def prepare_text_prototype(self, device) -> None: + """The function to prepare text prototypes with prompt.""" + class_embeddings = [] + for classname in track_on_main_process(self.prototype, + 'Prepare text prototype...'): + # format with class + texts = [prompt(classname) for prompt in self.prompt] + tokenized_texts = self.tokenize(texts) + class_features = self.extract_text_feat(tokenized_texts.to(device)) + class_features /= class_features.norm(dim=-1, keepdim=True) + class_feature = class_features.mean(dim=0) + class_feature /= class_feature.norm() + class_embeddings.append(class_feature) + self.text_prototype_embeds = torch.stack( + class_embeddings, dim=1).to(device) + + def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Args: + texts (Union[str, List[str]]): An input string or a list of input + strings to tokenize + context_length (int): The context length to use. Defaults to 52. + + Returns: + torch.Tensor: Resulting tokens. 
+ """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + # adapt the text to Chinese BERT vocab + text = text.lower().replace('“', "\"").replace('”', "\"") + + # add special tokens + all_tokens.append( + [self.tokenizer.vocab['[CLS]']] + + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:self.context_length - 2] + + [self.tokenizer.vocab['[SEP]']]) + + result = torch.zeros( + len(all_tokens), self.context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= self.context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/mmpretrain/models/multimodal/chinese_clip/utils.py b/mmpretrain/models/multimodal/chinese_clip/utils.py new file mode 100644 index 0000000..6964722 --- /dev/null +++ b/mmpretrain/models/multimodal/chinese_clip/utils.py @@ -0,0 +1,186 @@ +# Copyright (c) OpenMMLab. All rights reserved. +OPENAI_PROMPT = [ + lambda c: f'{c}的照片', + lambda c: f'质量差的{c}的照片', + lambda c: f'许多{c}的照片', + lambda c: f'{c}的雕塑', + lambda c: f'难以看到{c}的照片', + lambda c: f'{c}的低分辨率照片', + lambda c: f'{c}的渲染', + lambda c: f'涂鸦{c}', + lambda c: f'{c}的糟糕照片', + lambda c: f'{c}的裁剪照片', + lambda c: f'{c}的纹身', + lambda c: f'{c}的刺绣照片', + lambda c: f'很难看到{c}的照片', + lambda c: f'{c}的明亮照片', + lambda c: f'一张干净的{c}的照片', + lambda c: f'一张包含{c}的照片', + lambda c: f'{c}的深色照片', + lambda c: f'{c}的手绘画', + lambda c: f'我的{c}的照片', + lambda c: f'不自然的{c}的照片', + lambda c: f'一张酷的{c}的照片', + lambda c: f'{c}的特写照片', + lambda c: f'{c}的黑白照片', + lambda c: f'一幅{c}的画', + lambda c: f'一幅{c}的绘画', + lambda c: f'一张{c}的像素照片', + lambda c: f'{c}的雕像', + lambda c: f'一张{c}的明亮照片', + lambda c: f'{c}的裁剪照片', + lambda c: f'人造的{c}的照片', + lambda c: f'一张关于{c}的照片', + lambda c: f'损坏的{c}的jpeg照片', + lambda c: f'{c}的模糊照片', + lambda c: f'{c}的相片', + lambda c: f'一张{c}的好照片', + lambda c: f'{c}的渲染照', + lambda c: f'视频游戏中的{c}', + lambda c: f'一张{c}的照片', + lambda c: f'{c}的涂鸦', + lambda c: f'{c}的近距离照片', + lambda c: f'{c}的折纸', + lambda c: f'{c}在视频游戏中', + lambda c: f'{c}的草图', + lambda c: f'{c}的涂鸦照', + lambda c: f'{c}的折纸形状', + lambda c: f'低分辨率的{c}的照片', + lambda c: f'玩具{c}', + lambda c: f'{c}的副本', + lambda c: f'{c}的干净的照片', + lambda c: f'一张大{c}的照片', + lambda c: f'{c}的重现', + lambda c: f'一张漂亮的{c}的照片', + lambda c: f'一张奇怪的{c}的照片', + lambda c: f'模糊的{c}的照片', + lambda c: f'卡通{c}', + lambda c: f'{c}的艺术作品', + lambda c: f'{c}的素描', + lambda c: f'刺绣{c}', + lambda c: f'{c}的像素照', + lambda c: f'{c}的拍照', + lambda c: f'{c}的损坏的照片', + lambda c: f'高质量的{c}的照片', + lambda c: f'毛绒玩具{c}', + lambda c: f'漂亮的{c}的照片', + lambda c: f'小{c}的照片', + lambda c: f'照片是奇怪的{c}', + lambda c: f'漫画{c}', + lambda c: f'{c}的艺术照', + lambda c: f'{c}的图形', + lambda c: f'大{c}的照片', + lambda c: f'黑白的{c}的照片', + lambda c: f'{c}毛绒玩具', + lambda c: f'一张{c}的深色照片', + lambda c: f'{c}的摄影图', + lambda c: f'{c}的涂鸦照', + lambda c: f'玩具形状的{c}', + lambda c: f'拍了{c}的照片', + lambda c: f'酷酷的{c}的照片', + lambda c: f'照片里的小{c}', + lambda c: f'{c}的刺青', + lambda c: f'{c}的可爱的照片', + lambda c: f'一张{c}可爱的照片', + lambda c: f'{c}可爱图片', + lambda c: f'{c}酷炫图片', + lambda c: f'一张{c}的酷炫的照片', + lambda c: f'一张{c}的酷炫图片', + lambda c: f'这是{c}', + lambda c: f'{c}的好看照片', + lambda c: f'一张{c}的好看的图片', + lambda c: f'{c}的好看图片', + lambda c: f'{c}的照片。', + lambda c: f'质量差的{c}的照片。', + lambda c: f'许多{c}的照片。', + lambda c: f'{c}的雕塑。', + lambda c: f'难以看到{c}的照片。', + lambda c: f'{c}的低分辨率照片。', + lambda c: f'{c}的渲染。', + lambda c: f'涂鸦{c}。', + lambda c: f'{c}的糟糕照片。', + lambda c: f'{c}的裁剪照片。', + lambda c: f'{c}的纹身。', + lambda c: f'{c}的刺绣照片。', + lambda c: 
f'很难看到{c}的照片。', + lambda c: f'{c}的明亮照片。', + lambda c: f'一张干净的{c}的照片。', + lambda c: f'一张包含{c}的照片。', + lambda c: f'{c}的深色照片。', + lambda c: f'{c}的手绘画。', + lambda c: f'我的{c}的照片。', + lambda c: f'不自然的{c}的照片。', + lambda c: f'一张酷的{c}的照片。', + lambda c: f'{c}的特写照片。', + lambda c: f'{c}的黑白照片。', + lambda c: f'一幅{c}的画。', + lambda c: f'一幅{c}的绘画。', + lambda c: f'一张{c}的像素照片。', + lambda c: f'{c}的雕像。', + lambda c: f'一张{c}的明亮照片。', + lambda c: f'{c}的裁剪照片。', + lambda c: f'人造的{c}的照片。', + lambda c: f'一张关于{c}的照片。', + lambda c: f'损坏的{c}的jpeg照片。', + lambda c: f'{c}的模糊照片。', + lambda c: f'{c}的相片。', + lambda c: f'一张{c}的好照片。', + lambda c: f'{c}的渲染照。', + lambda c: f'视频游戏中的{c}。', + lambda c: f'一张{c}的照片。', + lambda c: f'{c}的涂鸦。', + lambda c: f'{c}的近距离照片。', + lambda c: f'{c}的折纸。', + lambda c: f'{c}在视频游戏中。', + lambda c: f'{c}的草图。', + lambda c: f'{c}的涂鸦照。', + lambda c: f'{c}的折纸形状。', + lambda c: f'低分辨率的{c}的照片。', + lambda c: f'玩具{c}。', + lambda c: f'{c}的副本。', + lambda c: f'{c}的干净的照片。', + lambda c: f'一张大{c}的照片。', + lambda c: f'{c}的重现。', + lambda c: f'一张漂亮的{c}的照片。', + lambda c: f'一张奇怪的{c}的照片。', + lambda c: f'模糊的{c}的照片。', + lambda c: f'卡通{c}。', + lambda c: f'{c}的艺术作品。', + lambda c: f'{c}的素描。', + lambda c: f'刺绣{c}。', + lambda c: f'{c}的像素照。', + lambda c: f'{c}的拍照。', + lambda c: f'{c}的损坏的照片。', + lambda c: f'高质量的{c}的照片。', + lambda c: f'毛绒玩具{c}。', + lambda c: f'漂亮的{c}的照片。', + lambda c: f'小{c}的照片。', + lambda c: f'照片是奇怪的{c}。', + lambda c: f'漫画{c}。', + lambda c: f'{c}的艺术照。', + lambda c: f'{c}的图形。', + lambda c: f'大{c}的照片。', + lambda c: f'黑白的{c}的照片。', + lambda c: f'{c}毛绒玩具。', + lambda c: f'一张{c}的深色照片。', + lambda c: f'{c}的摄影图。', + lambda c: f'{c}的涂鸦照。', + lambda c: f'玩具形状的{c}。', + lambda c: f'拍了{c}的照片。', + lambda c: f'酷酷的{c}的照片。', + lambda c: f'照片里的小{c}。', + lambda c: f'{c}的刺青。', + lambda c: f'{c}的可爱的照片。', + lambda c: f'一张{c}可爱的照片。', + lambda c: f'{c}可爱图片。', + lambda c: f'{c}酷炫图片。', + lambda c: f'一张{c}的酷炫的照片。', + lambda c: f'一张{c}的酷炫图片。', + lambda c: f'这是{c}。', + lambda c: f'{c}的好看照片。', + lambda c: f'一张{c}的好看的图片。', + lambda c: f'{c}的好看图片。', + lambda c: f'一种叫{c}的花的照片', + lambda c: f'一种叫{c}的食物的照片', + lambda c: f'{c}的卫星照片', +] diff --git a/mmpretrain/models/multimodal/clip/__init__.py b/mmpretrain/models/multimodal/clip/__init__.py new file mode 100644 index 0000000..f7a117e --- /dev/null +++ b/mmpretrain/models/multimodal/clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..clip.clip import CLIP, CLIPZeroShot +from ..clip.clip_transformer import CLIPProjection, CLIPTransformer + +__all__ = ['CLIP', 'CLIPZeroShot', 'CLIPTransformer', 'CLIPProjection'] diff --git a/mmpretrain/models/multimodal/clip/clip.py b/mmpretrain/models/multimodal/clip/clip.py new file mode 100644 index 0000000..b509a63 --- /dev/null +++ b/mmpretrain/models/multimodal/clip/clip.py @@ -0,0 +1,364 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
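These prompt templates feed `prepare_text_prototype` above: each class name is formatted with every template, encoded by the text branch, L2-normalized, averaged, and re-normalized into one prototype per class. The following toy sketch mirrors that flow with a stand-in bag-of-bytes "encoder"; `encode_text`, its 64-d output, and the two-template ensemble are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Stand-in "text encoder": a fixed random projection from a bag-of-bytes
# vector to a 64-d embedding. In the real model this is the BERT text branch.
torch.manual_seed(0)
proj = torch.randn(256, 64)

def encode_text(text: str) -> torch.Tensor:
    bag = torch.zeros(256)
    for byte in text.encode('utf-8'):
        bag[byte] += 1
    return bag @ proj

templates = [lambda c: f'{c}的照片', lambda c: f'一张{c}的好照片']  # tiny ensemble
classnames = ['猫', '狗']

prototypes = []
for name in classnames:
    feats = torch.stack([encode_text(t(name)) for t in templates])
    feats = F.normalize(feats, dim=-1)               # normalize each prompt feature
    proto = F.normalize(feats.mean(dim=0), dim=-1)   # average, then re-normalize
    prototypes.append(proto)
prototypes = torch.stack(prototypes, dim=1)          # (embed_dim, num_classes)

image_feat = F.normalize(torch.randn(64), dim=-1)
logits = image_feat @ prototypes                     # cosine similarity per class
print(logits.softmax(dim=-1))
```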
+from abc import abstractmethod +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES, + IMAGENET_SIMPLE_CATEGORIES) +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT, + OPENAI_IMAGENET_PROMPT_SUB) + +CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES] +PROTOTYPE_MAP = { + 'imagenet': IMAGENET_SIMPLE_CATEGORIES, + 'cifar100': CIFAR100_CATEGORIES, +} +PROMPT_MAP = { + 'openai_imagenet': OPENAI_IMAGENET_PROMPT, + 'openai_cifar100': OPENAI_CIFAR100_PROMPT, + 'vanilla': [lambda c: f'a photo of a {c}'], + 'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB +} + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class CLIP(BaseModel): + """The implementation of `CLIP `_. + + Args: + vision_backbone (dict): Config dict for vision backbone. + text_backbone (dict): Config dict for text backbone. + tokenizer (dict): Config dict for text tokenizer. + proj_dim (int): Projection dimension for similarity computation. + text_prototype (str): Text prototype, which can be a key in + `PROTOTYPE_MAP` or list of text. + text_prompt (str): The prompt for text prototype. + Defaults to 'vanilla',which refers to "a photo of {cls}". + context_length (int): The context length to use. Defaults to 77. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The config to control the initialization. + Defaults to None. 
+ """ + + def __init__(self, + vision_backbone: dict, + projection: dict, + text_backbone: dict, + tokenizer: dict, + vocab_size: int, + transformer_width: int, + proj_dim: int, + context_length: int = 77, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.context_length = context_length + + # build the vision transformer + self.visual = MODELS.build(vision_backbone) + + # build the visual projection + self.visual_proj = MODELS.build(projection) + + # build attn_mask for casual-attn + text_backbone['attn_mask'] = self.build_attention_mask() + + # build the text transformer + self.transformer = MODELS.build(text_backbone) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, proj_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + self.tokenizer = TOKENIZER.build(tokenizer) + + self.tokenizer.vocab = self.tokenizer.get_vocab( + ) # CLIPTokenizer has no attribute named 'vocab', so manually + + def initialize_parameters(self) -> None: + """Initialize the parameters. + + The pretrained weight will override the initialized parameters by this + function. + """ + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, + # with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. 
+ """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: + """The function to extract image latent features.""" + return self.visual_proj(self.visual(images))[0] + + def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: + """The function to extract text latent features.""" + x = self.token_embedding(texts) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x)[0] + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding + # (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + texts.argmax(dim=-1)] @ self.text_projection + + return x + + def extract_feat( + self, images: torch.Tensor, + texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """The function to extract image and text latent features, the input + image or text can not both be None.""" + + assert images is not None or texts is not None, \ + 'text and image cannot both be None!' + if images is None: + return self.extract_text_feat(texts) + elif texts is None: + return self.extract_image_feat(images) + + image_features = self.extract_image_feat(images) + text_features = self.extract_text_feat(texts) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + return image_features, text_features + + def compute_similarity(self, images, texts): + """Extract images and texts features and compute cosine similarity.""" + image_features, text_features = self.extract_feat( + images=images, texts=texts) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape (N, N) + return logits_per_image, logits_per_text + + @abstractmethod + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + raise NotImplementedError + + def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Args: + texts (Union[str, List[str]]): An input string or a list of input + strings to tokenize + context_length (int): The context length to use. Defaults to 52. + + Returns: + torch.Tensor: Resulting tokens. 
+ """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + # adapt the text to Chinese BERT vocab + # text = text.lower().replace('“', "\"").replace('”', "\"") + + # add special tokens + all_tokens.append( + [self.tokenizer.vocab['<|startoftext|>'] + ] + # <|startoftext|>代表[CLS] token + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:self.context_length - 2] + + [self.tokenizer.vocab['<|endoftext|>']]) + + result = torch.zeros( + len(all_tokens), self.context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= self.context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +@MODELS.register_module() +class CLIPZeroShot(CLIP): + + def __init__( + self, + vision_backbone: dict, + projection: dict, + text_backbone: dict, + tokenizer: dict, + vocab_size: int, + transformer_width: int, + proj_dim: int, + context_length: int = 77, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + text_prototype: Union[str, List[str]] = 'imagenet', + text_prompt: str = 'vanilla', + ): + super(CLIPZeroShot, + self).__init__(vision_backbone, projection, text_backbone, + tokenizer, vocab_size, transformer_width, + proj_dim, context_length, data_preprocessor, + init_cfg) + + # for zero-shot classification + if isinstance(text_prototype, + str) and text_prototype in PROTOTYPE_MAP.keys(): + self.prototype = PROTOTYPE_MAP[text_prototype] + else: + self.prototype = text_prototype + self.text_prototype_embeds = None + + self.prompt = PROMPT_MAP[text_prompt] + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + """Predict the classes of the input images. + + The prediction is for zero-shot classification and the text prototypes + will be prepared in thisfunction. + + Args: + images (torch.Tensor): The input images. + data_samples (DataSample): The data samples with information from + dataset. + + Returns: + DataSample: The results of prediction. 
+ """ + + if self.text_prototype_embeds is None: + self.prepare_text_prototype(device=images.device) + + image_features = self.extract_image_feat(images=images) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = image_features @ self.text_prototype_embeds.to( + image_features.device) * self.logit_scale.exp() + + pred_scores = F.softmax(logits_per_image, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples + + def prepare_text_prototype(self, device) -> None: + """The function to prepare text prototypes with prompt.""" + class_embeddings = [] + for classname in track_on_main_process(self.prototype, + 'Prepare text prototype...'): + # format with class + texts = [prompt(classname) for prompt in self.prompt] + tokenized_texts = self.tokenize(texts) + class_features = self.extract_text_feat(tokenized_texts.to(device)) + class_features /= class_features.norm(dim=-1, keepdim=True) + class_feature = class_features.mean(dim=0) + class_feature /= class_feature.norm() + class_embeddings.append(class_feature) + self.text_prototype_embeds = torch.stack( + class_embeddings, dim=1).to(device) diff --git a/mmpretrain/models/multimodal/clip/clip_transformer.py b/mmpretrain/models/multimodal/clip/clip_transformer.py new file mode 100644 index 0000000..4b5f766 --- /dev/null +++ b/mmpretrain/models/multimodal/clip/clip_transformer.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/zejiangh/MILAN +from typing import Optional, Tuple + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.models.utils.clip_generator_helper import \ + ResidualAttentionBlock +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CLIPTransformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +@MODELS.register_module() +class CLIPProjection(BaseModule): + """Neck with CLIP Projection. + + Args: + in_channels (int): Number of channels in the input. + out_channels (int): Number of channels in the output. + init_cfg (dict | list[dict], optional): Initialization config dict. 
+ Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + init_cfg: Optional[dict] = None): + super(CLIPProjection, self).__init__(init_cfg=init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + scale = in_channels**-0.5 + self.proj = nn.Parameter(scale * + torch.randn(in_channels, out_channels)) + + def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]: + """forward function. + + Args: + inputs (Tuple): The features extracted from + the backbone. Multiple stage inputs are acceptable but only + the last stage will be used. + Returns: + Tuple(torch.Tensor)): A tuple of reducted features. + """ + if isinstance(inputs, tuple): + inputs = inputs[-1] + out = inputs @ self.proj + elif isinstance(inputs, torch.Tensor): + out = inputs @ self.proj + else: + raise TypeError( + '`CLIPProjection` neck inputs should be tuple or torch.tensor') + return (out, ) diff --git a/mmpretrain/models/multimodal/clip/utils.py b/mmpretrain/models/multimodal/clip/utils.py new file mode 100644 index 0000000..65239bc --- /dev/null +++ b/mmpretrain/models/multimodal/clip/utils.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +OPENAI_CIFAR100_PROMPT = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'a low contrast photo of a {c}.', + lambda c: f'a high contrast photo of a {c}.', + lambda c: f'a bad photo of a {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a photo of a big {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a low contrast photo of the {c}.', + lambda c: f'a high contrast photo of the {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the big {c}.', +] + +OPENAI_IMAGENET_PROMPT_SUB = [ + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +] + +OPENAI_IMAGENET_PROMPT = [ + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg 
corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] diff --git a/mmpretrain/models/multimodal/flamingo/__init__.py b/mmpretrain/models/multimodal/flamingo/__init__.py new file mode 100644 index 0000000..e0bfd63 --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .adapter import FlamingoLMAdapter +from .flamingo import Flamingo + +__all__ = ['Flamingo', 'FlamingoLMAdapter'] diff --git a/mmpretrain/models/multimodal/flamingo/adapter.py b/mmpretrain/models/multimodal/flamingo/adapter.py new file mode 100644 index 0000000..bef0e2f --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/adapter.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +import torch.nn as nn + +from mmpretrain.registry import MODELS +from .modules import FlamingoLayer, GatedCrossAttentionBlock +from .utils import getattr_recursive, setattr_recursive + + +@MODELS.register_module() +class FlamingoLMAdapter: + """Mixin to add cross-attention layers to a language model.""" + + @classmethod + def extend_init( + cls, + base: object, + vis_hidden_size: int, + cross_attn_every_n_layers: int, + use_media_placement_augmentation: bool, + only_attend_previous: bool = False, + ): + """Initialize Flamingo by adding a new gated cross attn to the decoder. + + Store the media token id for computing the media locations. + + Args: + base (object): Base module could be any object that represent + a instance of language model. + vis_hidden_size: (int): Hidden size of vision embeddings. + cross_attn_every_n_layers: (int): Additional cross attn for + every n layers. 
+ use_media_placement_augmentation: (bool): Whether to use media + placement augmentation. + """ + base.set_decoder_layers_attr_name('model.layers') + gated_cross_attn_layers = nn.ModuleList([ + GatedCrossAttentionBlock( + dim=base.config.hidden_size, dim_visual=vis_hidden_size) if + (layer_idx + 1) % cross_attn_every_n_layers == 0 else None + for layer_idx, _ in enumerate(base._get_decoder_layers()) + ]) + base._set_decoder_layers( + nn.ModuleList([ + FlamingoLayer(gated_cross_attn_layer, decoder_layer) + for gated_cross_attn_layer, decoder_layer in zip( + gated_cross_attn_layers, base._get_decoder_layers()) + ])) + base.use_media_placement_augmentation = use_media_placement_augmentation # noqa + base.initialized_flamingo = True + base.only_attend_previous = only_attend_previous + return base + + def set_decoder_layers_attr_name(self, decoder_layers_attr_name): + """Set decoder layers attribute name.""" + self.decoder_layers_attr_name = decoder_layers_attr_name + + def _get_decoder_layers(self): + """Get decoder layers according to attribute name.""" + return getattr_recursive(self, self.decoder_layers_attr_name) + + def _set_decoder_layers(self, value): + """Set decoder layers according to attribute name.""" + setattr_recursive(self, self.decoder_layers_attr_name, value) + + def forward(self, *input, **kwargs): + """Condition the Flamingo layers on the media locations before forward + function.""" + input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else input[0] + media_locations = input_ids == self.media_token_id + if self.only_attend_previous: + attend_previous = True + elif self.use_media_placement_augmentation: + attend_previous = (random.random() < 0.5) + else: + attend_previous = False + + for layer in self.get_decoder().layers: + layer.condition_media_locations(media_locations) + layer.condition_attend_previous(attend_previous) + + return super().forward( + *input, **kwargs) # Call the other parent's forward method + + def is_conditioned(self) -> bool: + """Check whether all decoder layers are already conditioned.""" + return all(layer.is_conditioned() + for layer in self._get_decoder_layers()) + + def clear_conditioned_layers(self): + """Clear all conditional layers.""" + for layer in self._get_decoder_layers(): + layer.condition_vis_x(None) + layer.condition_media_locations(None) + layer.condition_attend_previous(None) diff --git a/mmpretrain/models/multimodal/flamingo/flamingo.py b/mmpretrain/models/multimodal/flamingo/flamingo.py new file mode 100644 index 0000000..729d6c7 --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/flamingo.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .modules import PerceiverResampler +from .utils import ExtendModule + + +@MODELS.register_module() +class Flamingo(BaseModel): + """The Open Flamingo model for multiple tasks. + + Args: + vision_encoder (dict): The config of the vision encoder. + lang_encoder (dict): The config of the language encoder. + tokenizer (dict): The tokenizer to encode the text. + task (int): The task to perform prediction. + zeroshot_prompt (str): Prompt used for zero-shot inference. + Defaults to 'Output:'. + shot_prompt_tmpl (str): Prompt used for few-shot inference. + Defaults to ``Output:{caption}<|endofchunk|>``. + final_prompt_tmpl (str): Final part of prompt used for inference. 
+            Defaults to '<image>Output:'.
+        generation_cfg (dict): The extra generation config, accepts the keyword
+            arguments of [~`transformers.GenerationConfig`].
+            Defaults to an empty dict.
+        data_preprocessor (Optional[dict]): The config for preprocessing input
+            data. If None or no specified type, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (dict, optional): The initialization config. Defaults to None.
+    """
+
+    support_tasks = {'caption', 'vqa'}
+    _no_split_modules = [
+        'TransformerEncoderLayer', 'PerceiverAttention',
+        'GatedCrossAttentionBlock', 'FlamingoLayer'
+    ]
+
+    def __init__(
+            self,
+            vision_encoder: dict,
+            lang_encoder: dict,
+            tokenizer: dict,
+            task: str = 'caption',
+            zeroshot_prompt: str = '<image>Output:',
+            shot_prompt_tmpl: str = '<image>Output:{caption}<|endofchunk|>',
+            final_prompt_tmpl: str = '<image>Output:',
+            generation_cfg: dict = dict(),
+            data_preprocessor: Optional[dict] = None,
+            init_cfg: Optional[dict] = None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super().__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        if task not in self.support_tasks:
+            raise ValueError(f'Unsupported task {task}, please select '
+                             f'the task from {self.support_tasks}.')
+        self.task = task
+
+        # init tokenizer
+        self.tokenizer = TOKENIZER.build(tokenizer)
+        # add Flamingo special tokens to the tokenizer
+        self.tokenizer.add_special_tokens(
+            {'additional_special_tokens': ['<|endofchunk|>', '<image>']})
+        self.tokenizer.bos_token_id = 1
+        if self.tokenizer.pad_token is None:
+            # Issue: GPT models don't have a pad token, which we use to
+            # modify labels for the loss.
+            self.tokenizer.add_special_tokens({'pad_token': '<PAD>'})
+
+        # Template to format the prompt input
+        self.zeroshot_prompt = zeroshot_prompt
+        self.shot_prompt_tmpl = shot_prompt_tmpl
+        self.final_prompt_tmpl = final_prompt_tmpl
+
+        # init vision encoder related modules
+        vision_encoder_weight = vision_encoder.pop('pretrained', None)
+        self.vision_encoder = MODELS.build(vision_encoder)
+        if vision_encoder_weight is not None:
+            from mmengine.runner.checkpoint import load_checkpoint
+            load_checkpoint(
+                self.vision_encoder,
+                vision_encoder_weight,
+                map_location='cpu',
+                revise_keys=[(r'^backbone\.', '')],
+            )
+            self.vision_encoder.is_init = True
+
+        self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
+
+        # init language encoder related modules
+        self.lang_encoder = ExtendModule(**lang_encoder)
+        self.lang_encoder.resize_token_embeddings(len(self.tokenizer))
+        self.lang_encoder.media_token_id = self.tokenizer.encode('<image>')[-1]
+
+        # other necessary parameters
+        self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1]
+        self.generation_cfg = {
+            'num_beams': 1,
+            'max_new_tokens': None,
+            'temperature': 1.0,
+            'top_k': 0,
+            'top_p': 1.0,
+            'no_repeat_ngram_size': 0,
+            'prefix_allowed_tokens_fn': None,
+            'length_penalty': 1.0,
+            'num_return_sequences': 1,
+            'do_sample': False,
+            'early_stopping': False,
+            **generation_cfg,
+        }
+
+        if hasattr(self, 'register_load_state_dict_post_hook'):
+            self.register_load_state_dict_post_hook(self._load_adapter_hook)
+
+    def forward(
+        self,
+        images: torch.Tensor,
+        data_samples: Optional[List[DataSample]] = None,
+        mode: str = 'loss',
+    ):
+        """The unified entry for a forward process in both training and test.
+        The method accepts the following modes:
+
+        - "loss": Forward and return a dict of losses according to the given
+          inputs and data samples.
+        - "predict": Forward and return a list of data samples containing the
+          prediction results.
+
+        Note that this method handles neither back propagation nor
+        optimizer updating, which are done in the :meth:`train_step`.
+
+        Args:
+            images (torch.Tensor): The input image tensor with different ndim
+                according to the inputs.
+            data_samples (List[DataSample], optional): The annotation
+                data of every samples. It's required if ``mode="loss"``.
+                Defaults to None.
+            mode (str): Return what kind of value. Defaults to 'loss'.
+
+        Returns:
+            The return type depends on ``mode``.
+            - If ``mode="loss"``, return a dict of tensor.
+        """
+
+        if mode == 'loss':
+            return self.loss(images, data_samples)
+        elif mode == 'predict':
+            return self.predict(images, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}".')
+
+    def extract_vision_feats(self, images: torch.Tensor) -> torch.Tensor:
+        """Extract vision features.
+
+        Args:
+            images (torch.Tensor): For zero-shot, the input images tensor is
+                with shape (B, C, H, W), for few-shot, which is
+                (B, T_img, C, H, W) in general. Images in the same chunk
+                are collated along T_img. Video data is not supported yet.
+
+        Returns:
+            torch.Tensor: Return extracted features.
+        """
+        if images.ndim == 4:
+            # (B, C, H, W) -> (B, 1, C, H, W) for zero-shot.
+ images = images.unsqueeze(1) + b, T = images.shape[:2] + # b T c h w -> (b T) c h w + images = images.view(b * T, *images.shape[-3:]) + + with torch.no_grad(): + vision_feats = self.vision_encoder(images)[-1][:, 1:] + + # (b T F) v d -> b T F v d Only support F=1 here + vision_feats = vision_feats.view(b, T, 1, *vision_feats.shape[-2:]) + + vision_feats = self.perceiver(vision_feats) # reshapes to (b, T, n, d) + return vision_feats + + def predict(self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **generation_cfg): + """Predict generation results from a batch of inputs. + + Args: + images (torch.Tensor): For zero-shot, the input images tensor is + with shape (B, C, H, W), for few-shot, which is + (B, T_img, C, H, W) in general. Images in the same chunk + are collated along T_img. Video data is not supported yet. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **generation_cfg: Other keyword arguments accepted by the + ``generate`` method of :attr:`lang_encoder`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # generation_cfg in prediction should be dominant + generation_cfg = {**self.generation_cfg, **generation_cfg} + num_beams = generation_cfg['num_beams'] + + if num_beams > 1: + images = images.repeat_interleave(num_beams, dim=0) + + # extra vision feats and set as language condition feats + vision_x = self.extract_vision_feats(images) + for layer in self.lang_encoder._get_decoder_layers(): + layer.condition_vis_x(vision_x) + + input_text = self.preprocess_text(data_samples, device=images.device) + + outputs = self.lang_encoder.generate( + input_text.input_ids, + attention_mask=input_text.attention_mask, + eos_token_id=self.eoc_token_id, + **generation_cfg) + + # clear conditioned layers for language models + self.lang_encoder.clear_conditioned_layers() + + # remove prefix + outputs = outputs[:, len(input_text.input_ids[0]):] + + return self.post_process(outputs, data_samples) + + def preprocess_text(self, data_samples: List[DataSample], + device: torch.device) -> List[DataSample]: + """Preprocess text in advance before fed into language model. + + Args: + data_samples (List[DataSample]): The annotation + data of every samples. Defaults to None. + device (torch.device): Device for text to put on. + + Returns: + List[DataSample]: Return list of data samples. + """ + prompts = [] + for sample in data_samples: + if 'shots' in sample: + # few-shot + shot_prompt = ''.join([ + self.shot_prompt_tmpl.format(**shot) + for shot in sample.get('shots') + ]) + else: + # zero-shot + shot_prompt = self.zeroshot_prompt + + # add final prompt + final_prompt = self.final_prompt_tmpl.format(**sample.to_dict()) + prompts.append(shot_prompt + final_prompt) + + self.tokenizer.padding_side = 'left' + input_text = self.tokenizer( + prompts, + padding='longest', + truncation=True, + return_tensors='pt', + max_length=2000, + ).to(device) + return input_text + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. 
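To make the prompt assembly in `preprocess_text` above concrete, here is a toy sketch of how few-shot captions are folded into a single prompt string. The captions are invented; the templates follow the class defaults, where the `<image>` media token (as in OpenFlamingo) marks the position whose features are injected via cross-attention:

```python
shot_prompt_tmpl = '<image>Output:{caption}<|endofchunk|>'
final_prompt_tmpl = '<image>Output:'

shots = [
    {'caption': 'a dog in the snow'},   # hypothetical few-shot examples
    {'caption': 'two red bikes'},
]

# Each shot becomes "<image>Output:<caption><|endofchunk|>", then the final
# "<image>Output:" stub asks the model to continue for the query image.
shot_prompt = ''.join(shot_prompt_tmpl.format(**shot) for shot in shots)
prompt = shot_prompt + final_prompt_tmpl.format()
print(prompt)
```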
+ """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = re.split('Output', output, + 1)[0].replace('"', '') + elif self.task == 'vqa': + data_sample.pred_answer = re.split('Question|Answer', output, + 1)[0] + + return data_samples + + @staticmethod + def _load_adapter_hook(module, incompatible_keys): + """Avoid warning missing keys except adapter keys.""" + adapter_patterns = [ + '^perceiver', + 'lang_encoder.*embed_tokens', + 'lang_encoder.*gated_cross_attn_layers', + 'lang_encoder.*rotary_emb', + ] + for key in list(incompatible_keys.missing_keys): + if not any(re.match(pattern, key) for pattern in adapter_patterns): + incompatible_keys.missing_keys.remove(key) + + for key in list(incompatible_keys.unexpected_keys): + if 'position_ids' in key: + incompatible_keys.unexpected_keys.remove(key) + if 'lang_encoder.gated_cross_attn_layers' in key: + incompatible_keys.unexpected_keys.remove(key) diff --git a/mmpretrain/models/multimodal/flamingo/modules.py b/mmpretrain/models/multimodal/flamingo/modules.py new file mode 100644 index 0000000..730c61b --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/modules.py @@ -0,0 +1,398 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Taken from https://github.com/lucidrains/flamingo-pytorch.""" + +from typing import Optional + +import torch +from einops import rearrange, repeat +from torch import einsum, nn + + +def FeedForward(dim, mult: int = 4): + """Feedforward layers. + + Args: + mult (int): Layer expansion muliplier. Defaults to 4. + """ + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + """Perceiver attetion layers. + + Args: + dim (int): Input dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + """ + + def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x: torch.Tensor, latents: torch.Tensor): + """Forward function. + + Args: + x (torch.Tensor): image features of shape (b, T, n1, D). + latent (torch.Tensor): latent features of shape (b, T, n2, D). + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q = rearrange(q, 'b t n (h d) -> b h t n d', h=h) + k = rearrange(k, 'b t n (h d) -> b h t n d', h=h) + v = rearrange(v, 'b t n (h d) -> b h t n d', h=h) + q = q * self.scale + + # attention + sim = einsum('... i d, ... j d -> ... i j', q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum('... i j, ... j d -> ... 
i d', attn, v) + out = rearrange(out, 'b h t n d -> b t n (h d)', h=h) + return self.to_out(out) + + +class PerceiverResampler(nn.Module): + """Perceiver resampler layers. + + Args: + dim (int): Input dimensions. + depth (int): Depth of resampler. Defaults to 6. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + num_latents (int): Number of latents. Defaults to 64. + max_num_media (int, optional): Max number of media. + Defaults to None. + max_num_frames (int, optional): Max number of frames. + Defaults to None. + ff_mult (int): Feed forward multiplier. Defaults to 4. + """ + + def __init__( + self, + *, + dim: int, + depth: int = 6, + dim_head: int = 64, + heads: int = 8, + num_latents: int = 64, + max_num_media: Optional[int] = None, + max_num_frames: Optional[int] = None, + ff_mult: int = 4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = ( + nn.Parameter(torch.randn(max_num_frames, dim)) + if max_num_frames is not None else None) + self.media_time_embs = ( + nn.Parameter(torch.randn(max_num_media, 1, dim)) + if max_num_media is not None else None) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention( + dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): image features of shape (b, T, F, v, D) + + Returns: + torch.Tensor: shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if self.frame_embs is not None: + frame_embs = repeat( + self.frame_embs[:F], 'F d -> b T F v d', b=b, T=T, v=v) + x = x + frame_embs + x = rearrange(x, 'b T F v d -> b T (F v) d' + ) # flatten the frame and spatial dimensions + if self.media_time_embs is not None: + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, 'n d -> b T n d', b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +class MaskedCrossAttention(nn.Module): + """Masked cross attention layers. + + Args: + dim (int): Input text feature dimensions. + dim_visual (int): Input visual feature dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + only_attend_immediate_media (bool): Whether attend immediate media. + Defaults to True. + """ + + def __init__( + self, + *, + dim: int, + dim_visual: int, + dim_head: int = 64, + heads: int = 8, + only_attend_immediate_media: bool = True, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether for text to only attend to immediate preceding image + # or all previous images + self.only_attend_immediate_media = only_attend_immediate_media + + def forward(self, + x: torch.Tensor, + media: torch.Tensor, + media_locations: Optional[torch.Tensor] = None, + attend_previous: bool = True): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): text features of shape (B, T_txt, D_txt). 
+ media (torch.Tensor): image features of shape + (B, T_img, n, D_img) where n is the dim of the latents. + media_locations (torch.Tensor, optional): boolean mask identifying + the media tokens in x of shape (B, T_txt). Defaults to None. + attend_previous (bool): If false, ignores immediately preceding + image and starts attending when following image. + Defaults to True. + """ + _, T_img, n = media.shape[:3] + h = self.heads + + x = self.norm(x) + + q = self.to_q(x) + media = rearrange(media, 'b t n d -> b (t n) d') + + k, v = self.to_kv(media).chunk(2, dim=-1) + q = rearrange(q, 'b n (h d) -> b h n d', h=h) + k = rearrange(k, 'b n (h d) -> b h n d', h=h) + v = rearrange(v, 'b n (h d) -> b h n d', h=h) + + q = q * self.scale + + sim = einsum('... i d, ... j d -> ... i j', q, k) + + if media_locations is not None: + # at each boolean of True, increment the time counter + # (relative to media time) + text_time = media_locations.cumsum(dim=-1) + media_time = torch.arange(T_img, device=x.device) + 1 + + if not attend_previous: + text_time[~media_locations] += 1 + # make sure max is still the number of images in the sequence + text_time[text_time > repeat( + torch.count_nonzero(media_locations, dim=1), + 'b -> b i', + i=text_time.shape[1], + )] = 0 + + # text time must equal media time if only attending to most + # immediate image otherwise, as long as text time is greater than + # media time (if attending to all previous images / media) + mask_op = torch.eq if self.only_attend_immediate_media else torch.ge # noqa + + text_to_media_mask = mask_op( + rearrange(text_time, 'b i -> b 1 i 1'), + repeat(media_time, 'j -> 1 1 1 (j n)', n=n), + ) + sim = sim.masked_fill(~text_to_media_mask, + -torch.finfo(sim.dtype).max) + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + if media_locations is not None and self.only_attend_immediate_media: + # any text without a preceding media needs to have + # attention zeroed out + text_without_media_mask = text_time == 0 + text_without_media_mask = rearrange(text_without_media_mask, + 'b i -> b 1 i 1') + attn = attn.masked_fill(text_without_media_mask, 0.0) + + out = einsum('... i j, ... j d -> ... i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class GatedCrossAttentionBlock(nn.Module): + """Gated cross attention layers. + + Args: + dim (int): Input text feature dimensions. + dim_visual (int): Input visual feature dimensions. + dim_head (int): Number of dimension heads. Defaults to 64. + heads (int): Number of heads. Defaults to 8. + ff_mult (int): Feed forward multiplier. Defaults to 4. + only_attend_immediate_media (bool): Whether attend immediate media. + Defaults to True. + """ + + def __init__( + self, + *, + dim: int, + dim_visual: int, + dim_head: int = 64, + heads: int = 8, + ff_mult: int = 4, + only_attend_immediate_media: bool = True, + ): + super().__init__() + self.attn = MaskedCrossAttention( + dim=dim, + dim_visual=dim_visual, + dim_head=dim_head, + heads=heads, + only_attend_immediate_media=only_attend_immediate_media, + ) + self.attn_gate = nn.Parameter(torch.tensor([0.0])) + + self.ff = FeedForward(dim, mult=ff_mult) + self.ff_gate = nn.Parameter(torch.tensor([0.0])) + + def forward(self, + x: torch.Tensor, + media: torch.Tensor, + media_locations: Optional[torch.Tensor] = None, + attend_previous: bool = True): + """Forward function for perceiver sampler. + + Args: + x (torch.Tensor): text features of shape (B, T_txt, D_txt). 
+ media (torch.Tensor): image features of shape + (B, T_img, n, D_img) where n is the dim of the latents. + media_locations (torch.Tensor, optional): boolean mask identifying + the media tokens in x of shape (B, T_txt). Defaults to None. + attend_previous (bool): If false, ignores immediately preceding + image and starts attending when following image. + Defaults to True. + """ + x = ( + self.attn( + x, + media, + media_locations=media_locations, + attend_previous=attend_previous, + ) * self.attn_gate.tanh() + x) + x = self.ff(x) * self.ff_gate.tanh() + x + + return x + + +class FlamingoLayer(nn.Module): + """Faminogo layers. + + Args: + gated_cross_attn_layer (nn.Module): Gated cross attention layer. + decoder_layer (nn.Module): Decoder layer. + """ + + def __init__(self, gated_cross_attn_layer: nn.Module, + decoder_layer: nn.Module): + super().__init__() + self.gated_cross_attn_layer = gated_cross_attn_layer + self.decoder_layer = decoder_layer + self.vis_x = None + self.media_locations = None + + def is_conditioned(self) -> bool: + """Check whether the layer is conditioned.""" + return self.vis_x is not None + + def condition_vis_x(self, vis_x): + """Set condition vision features.""" + self.vis_x = vis_x + + def condition_media_locations(self, media_locations): + """Set condition media locations.""" + self.media_locations = media_locations + + def condition_attend_previous(self, attend_previous): + """Set attend previous.""" + self.attend_previous = attend_previous + + def forward( + self, + lang_x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **decoder_layer_kwargs, + ): + """Forward function. + + Args: + lang_x (torch.Tensor): language inputs. + attention_mask (torch.Tensor, optional): text attention mask. + Defaults to None. + **decoder_layer_kwargs: Other decoder layer keyword arguments. + """ + if self.gated_cross_attn_layer is None: + return self.decoder_layer( + lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) + + if self.vis_x is None: + raise ValueError('vis_x must be conditioned before forward pass') + + if self.media_locations is None: + raise ValueError( + 'media_locations must be conditioned before forward pass') + + lang_x = self.gated_cross_attn_layer( + lang_x, + self.vis_x, + media_locations=self.media_locations, + attend_previous=self.attend_previous, + ) + lang_x = self.decoder_layer( + lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) + return lang_x diff --git a/mmpretrain/models/multimodal/flamingo/utils.py b/mmpretrain/models/multimodal/flamingo/utils.py new file mode 100644 index 0000000..1077e14 --- /dev/null +++ b/mmpretrain/models/multimodal/flamingo/utils.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Type + +from mmpretrain.registry import MODELS + + +class ExtendModule: + """Combine the base language model with adapter. This module will create a + instance from base with extended functions in adapter. + + Args: + base (object): Base module could be any object that represent + a instance of language model or a dict that can build the + base module. + adapter: (dict): Dict to build the adapter. 
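# Because both gates are initialised to zero, the gated block starts as an identity over
# the text stream; a quick sanity check with made-up sizes:
import torch
block = GatedCrossAttentionBlock(dim=512, dim_visual=1024)
text = torch.randn(2, 7, 512)
media = torch.randn(2, 1, 64, 1024)                 # (B, T_img, n, D_img)
locations = torch.zeros(2, 7, dtype=torch.bool)
locations[:, 0] = True                              # one leading media token
out = block(text, media, media_locations=locations)
torch.testing.assert_close(out, text)               # tanh(0) == 0, gates closed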
+ """ + + def __new__(cls, base: object, adapter: dict): + + if isinstance(base, dict): + base = MODELS.build(base) + + adapter_module = MODELS.get(adapter.pop('type')) + cls.extend_instance(base, adapter_module) + return adapter_module.extend_init(base, **adapter) + + @classmethod + def extend_instance(cls, base: object, mixin: Type[Any]): + """Apply mixins to a class instance after creation. + + Args: + base (object): Base module instance. + mixin: (Type[Any]): Adapter class type to mixin. + """ + base_cls = base.__class__ + base_cls_name = base.__class__.__name__ + base.__class__ = type( + base_cls_name, (mixin, base_cls), + {}) # mixin needs to go first for our forward() logic to work + + +def getattr_recursive(obj, att): + """ + Return nested attribute of obj + Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c + """ + if att == '': + return obj + i = att.find('.') + if i < 0: + return getattr(obj, att) + else: + return getattr_recursive(getattr(obj, att[:i]), att[i + 1:]) + + +def setattr_recursive(obj, att, val): + """ + Set nested attribute of obj + Example: setattr_recursive(obj, 'a.b.c', val) + is equivalent to obj.a.b.c = val + """ + if '.' in att: + obj = getattr_recursive(obj, '.'.join(att.split('.')[:-1])) + setattr(obj, att.split('.')[-1], val) diff --git a/mmpretrain/models/multimodal/llava/__init__.py b/mmpretrain/models/multimodal/llava/__init__.py new file mode 100644 index 0000000..aef10d3 --- /dev/null +++ b/mmpretrain/models/multimodal/llava/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .llava import Llava +from .modules import LlavaLlamaForCausalLM + +__all__ = ['Llava', 'LlavaLlamaForCausalLM'] diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py new file mode 100644 index 0000000..f829b09 --- /dev/null +++ b/mmpretrain/models/multimodal/llava/llava.py @@ -0,0 +1,267 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import List, Optional + +import torch +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from ...utils import no_load_hf_pretrained_model +from .modules import LlavaLlamaForCausalLM + + +@MODELS.register_module() +class Llava(BaseModel): + """The LLaVA model for multiple tasks. + + Args: + vision_encoder (dict): The config of the vision encoder. + lang_encoder (dict): The config of the language encoder. + tokenizer (dict): The tokenizer to encode the text. + prompt_tmpl (str): Prompt template for inference. + task (int): The task to perform prediction. + use_im_start_end (bool): Whether to use the im_start and im_end tokens + mm_vision_select_layer (int): The index from vision encoder output. + Defaults to -1. + mm_proj_depth (int): The number of linear layers for multi-modal + projection. Defaults to 1. + load_lang_pretrained (bool): Whether to load the pretrained model of + language encoder. Defaults to False. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of [~`transformers.GenerationConfig`]. + Defaults to an empty dict. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. If None or no specified type, it will use + "MutimodalDataPreprocessor" as type. + See :class:`MutimodalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The initialization config. Defaults to None. 
+ """ + + support_tasks = {'caption', 'vqa'} + im_patch_token = '' + im_start_token = '' + im_end_token = '' + + def __init__(self, + vision_encoder: dict, + lang_encoder: dict, + tokenizer: dict, + mm_hidden_size: int, + prompt_tmpl: str, + task: str = 'caption', + use_im_patch: bool = True, + use_im_start_end: bool = False, + mm_vision_select_layer: int = -1, + mm_proj_depth: int = 1, + generation_cfg: dict = dict(), + load_lang_pretrained: bool = False, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if task not in self.support_tasks: + raise ValueError(f'Unsupported task {task}, please select ' + f'the task from {self.support_tasks}.') + self.task = task + + # init tokenizer + self.tokenizer = TOKENIZER.build(tokenizer) + # add Llava special tokens to the tokenizer + if use_im_patch: + self.tokenizer.add_tokens([self.im_patch_token], + special_tokens=True) + if use_im_start_end: + self.tokenizer.add_tokens([self.im_start_token, self.im_end_token], + special_tokens=True) + + # Template to format the prompt input + self.prompt_tmpl = prompt_tmpl + + # init vision encoder related modules + vision_encoder_weight = vision_encoder.pop('pretrained', None) + vision_encoder = MODELS.build(vision_encoder) + if vision_encoder_weight is not None: + from mmengine.runner.checkpoint import load_checkpoint + load_checkpoint( + vision_encoder, + vision_encoder_weight, + map_location='cpu', + revise_keys=[(r'^backbone\.', '')], + ) + vision_encoder.is_init = True + + # init language encoder related modules + if load_lang_pretrained: + lang_encoder = MODELS.build(lang_encoder) + else: + with no_load_hf_pretrained_model(): + lang_encoder = MODELS.build(lang_encoder) + lang_encoder.resize_token_embeddings(len(self.tokenizer)) + + self.model = LlavaLlamaForCausalLM( + vision_encoder=vision_encoder, + lang_encoder=lang_encoder, + mm_hidden_size=mm_hidden_size, + mm_proj_depth=mm_proj_depth, + use_im_start_end=use_im_start_end, + im_start_token=self.tokenizer.convert_tokens_to_ids( + self.im_start_token), + im_end_token=self.tokenizer.convert_tokens_to_ids( + self.im_end_token), + mm_vision_select_layer=mm_vision_select_layer) + + self.generation_cfg = generation_cfg + + if hasattr(self, 'register_load_state_dict_post_hook'): + self.register_load_state_dict_post_hook(self._load_ckpt_hook) + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'loss', + ): + """The unified entry for a forward process in both training and test. + + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + images (torch.Tensor): The input image tensor with different ndim + according to the inputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'loss'. + + Returns: + The return type depends on ``mode``. 
+ - If ``mode="loss"``, return a dict of tensor. + """ + + if mode == 'predict': + return self.predict(images, data_samples) + elif mode == 'loss': + raise NotImplementedError + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict(self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + **generation_cfg): + """Predict generation results from a batch of inputs. + + Args: + images (torch.Tensor): For zero-shot, the input images tensor is + with shape (B, C, H, W), for few-shot, which is + (B, T_img, C, H, W) in general. Images in the same chunk + are collated along T_img. Video data is not supported yet. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **generation_cfg: Other keyword arguments accepted by the + ``generate`` method of :attr:`lang_encoder`. + + Returns: + List[DataSample]: Return list of data samples. + """ + # generation_cfg in prediction should be dominant + generation_cfg = {**self.generation_cfg, **generation_cfg} + + input_text = self.preprocess_text(data_samples, device=images.device) + + outputs = self.model.generate( + input_text.input_ids, + attention_mask=input_text.attention_mask, + eos_token_id=self.tokenizer.eos_token_id, + images=images, + **generation_cfg) + + # remove prefix + outputs = outputs[:, len(input_text.input_ids[0]):] + + return self.post_process(outputs, data_samples) + + def preprocess_text(self, data_samples: List[DataSample], + device: torch.device) -> List[DataSample]: + """Preprocess text in advance before fed into language model. + + Args: + data_samples (List[DataSample]): The annotation + data of every samples. Defaults to None. + device (torch.device): Device for text to put on. + + Returns: + List[DataSample]: Return list of data samples. + """ + tokens = [] + for sample in data_samples: + prompt = self.prompt_tmpl.format(**sample.to_dict()) + input_ids = [] + while '' in prompt: + prefix, _, prompt = prompt.partition('') + input_ids.extend( + self.tokenizer(prefix, add_special_tokens=False).input_ids) + input_ids.append(-200) + if prompt: + input_ids.extend( + self.tokenizer(prompt, add_special_tokens=False).input_ids) + tokens.append(dict(input_ids=input_ids)) + + self.tokenizer.padding_side = 'left' + input_text = self.tokenizer.pad( + tokens, + padding='longest', + return_tensors='pt', + max_length=2000, + ).to(device) + return input_text + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. 
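# Sketch of the placeholder insertion done in preprocess_text above, assuming the image
# placeholder literal is '<image>' and `tokenizer` is the HuggingFace tokenizer from __init__:
prompt = 'USER: <image>\nWhat is shown in the picture? ASSISTANT:'
prefix, _, rest = prompt.partition('<image>')
input_ids = (tokenizer(prefix, add_special_tokens=False).input_ids
             + [-200]     # im_token_index placeholder, swapped for image features later
             + tokenizer(rest, add_special_tokens=False).input_ids)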
+ """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = output + elif self.task == 'vqa': + data_sample.pred_answer = output + + return data_samples + + @staticmethod + def _load_ckpt_hook(module, incompatible_keys): + """Avoid warning missing keys except lang_encoder keys.""" + for key in list(incompatible_keys.missing_keys): + if re.match('model.vision_tower', key): + incompatible_keys.missing_keys.remove(key) diff --git a/mmpretrain/models/multimodal/llava/modules.py b/mmpretrain/models/multimodal/llava/modules.py new file mode 100644 index 0000000..fa3c6bb --- /dev/null +++ b/mmpretrain/models/multimodal/llava/modules.py @@ -0,0 +1,234 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel + +DEFAULT_IMAGE_TOKEN = '' +DEFAULT_IMAGE_PATCH_TOKEN = '' +DEFAULT_IM_START_TOKEN = '' +DEFAULT_IM_END_TOKEN = '' + + +class LlavaLlamaForCausalLM(PreTrainedModel): + + def __init__(self, + vision_encoder, + lang_encoder, + mm_hidden_size, + use_im_start_end=True, + mm_proj_depth=1, + im_start_token: Optional[int] = None, + im_end_token: Optional[int] = None, + im_token_index: int = -200, + mm_vision_select_layer: int = -1): + super().__init__(lang_encoder.config) + self.vision_tower = vision_encoder + self.lang_encoder = lang_encoder + + self.use_im_start_end = use_im_start_end + self.im_start_token = im_start_token + self.im_end_token = im_end_token + self.mm_hidden_size = mm_hidden_size + self.mm_vision_select_layer = mm_vision_select_layer + self.im_token_index = im_token_index + self.lang_hidden_size = lang_encoder.config.hidden_size + + if mm_proj_depth == 1: + # Llava V1 + mm_projector = nn.Linear(self.mm_hidden_size, + self.lang_hidden_size) + self.lang_encoder.model.add_module('mm_projector', mm_projector) + elif mm_proj_depth > 1: + # Llava V1.5 + modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)] + for _ in range(1, mm_proj_depth): + modules.append(nn.GELU()) + modules.append( + nn.Linear(self.lang_hidden_size, self.lang_hidden_size)) + mm_projector = nn.Sequential(*modules) + self.lang_encoder.model.add_module('mm_projector', mm_projector) + elif mm_proj_depth == 0: + self.lang_encoder.model.add_module('mm_projector', nn.Identity()) + + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: 
Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + (input_ids, attention_mask, past_key_values, inputs_embeds, + labels) = self.forward_vision_tower(input_ids, attention_mask, + past_key_values, labels, images) + + return self.lang_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use + # them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + 'images': kwargs.get('images', None), + }) + return model_inputs + + def forward_vision_tower( + self, + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + past_key_values: torch.FloatTensor, + labels: torch.LongTensor, + images: Union[torch.FloatTensor, None] = None, + ): + if self.vision_tower is None or images is None or input_ids.shape[ + 1] == 1: + if (past_key_values is not None and self.vision_tower is not None + and images is not None and input_ids.shape[1] == 1): + attention_mask = torch.ones( + (attention_mask.shape[0], + past_key_values[-1][-1].shape[-2] + 1), + dtype=attention_mask.dtype, + device=attention_mask.device) + return input_ids, attention_mask, past_key_values, None, labels + + with torch.no_grad(): + # TODO: support variable number of images (single now) + feats = self.vision_tower(images) + image_features = feats[-1][:, 1:] + + image_features = self.lang_encoder.model.mm_projector(image_features) + + new_input_embeds = [] + new_labels = [] if labels is not None else None + new_attn_mask = [] if attention_mask is not None else None + for batch_idx, cur_input_ids in enumerate(input_ids): + cur_img = image_features[batch_idx] + + if (cur_input_ids != self.im_token_index).all(): + # multimodal LLM, but the current sample is not multimodal + new_input_embeds.append(self.embed_tokens(cur_input_ids)) + if labels is not None: + new_labels.append(labels[batch_idx]) + if attention_mask is not None: + new_attn_mask.append(attention_mask[batch_idx]) + continue + + img_idx = torch.where(cur_input_ids == self.im_token_index)[0][0] + if self.use_im_start_end: + cur_new_input_embeds = torch.cat( + [ + self.embed_tokens(cur_input_ids[:img_idx - 1]), + self.embed_tokens(cur_input_ids[img_idx - 1:img_idx]), + cur_img, + self.embed_tokens( + cur_input_ids[img_idx + 1:img_idx + 2]), + self.embed_tokens(cur_input_ids[img_idx + 2:]), + ], + dim=0, + ) + else: + cur_new_input_embeds = torch.cat( + [ + self.embed_tokens(cur_input_ids[:img_idx]), + cur_img, + self.embed_tokens(cur_input_ids[img_idx + 1:]), + ], + dim=0, + ) + 
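# Toy walk-through of the token/feature splice in forward_vision_tower above
# (use_im_start_end=False branch, tiny made-up sizes):
import torch

C, n_patches = 4, 3
embed = torch.nn.Embedding(10, C)
cur_input_ids = torch.tensor([1, 2, -200, 3])       # -200 marks the image slot
cur_img = torch.randn(n_patches, C)                 # projected visual tokens
img_idx = torch.where(cur_input_ids == -200)[0][0]
spliced = torch.cat([
    embed(cur_input_ids[:img_idx]),                 # text before the image
    cur_img,                                        # image features
    embed(cur_input_ids[img_idx + 1:]),             # text after the image
], dim=0)
assert spliced.shape == (len(cur_input_ids) - 1 + n_patches, C)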
new_input_embeds.append(cur_new_input_embeds) + + if labels is not None: + cur_new_labels = torch.cat([ + labels[batch_idx, :img_idx], + labels.new_full((cur_img.size(0), ), -100), + labels[batch_idx, img_idx + 1:], + ], + dim=0) + new_labels.append(cur_new_labels) + + if attention_mask is not None: + cur_attn_mask = torch.cat([ + attention_mask[batch_idx, :img_idx], + attention_mask.new_full((cur_img.size(0), ), True), + attention_mask[batch_idx, img_idx + 1:], + ], + dim=0) + new_attn_mask.append(cur_attn_mask) + + inputs_embeds = torch.stack(new_input_embeds, dim=0) + if labels is not None: + labels = torch.stack(new_labels, dim=0) + if attention_mask is not None: + attention_mask = torch.stack(new_attn_mask, dim=0) + + return None, attention_mask, past_key_values, inputs_embeds, labels + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + def embed_tokens(self, input_ids): + return self.lang_encoder.model.embed_tokens(input_ids) diff --git a/mmpretrain/models/multimodal/minigpt4/__init__.py b/mmpretrain/models/multimodal/minigpt4/__init__.py new file mode 100644 index 0000000..5358bb1 --- /dev/null +++ b/mmpretrain/models/multimodal/minigpt4/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .minigpt4 import MiniGPT4 + +__all__ = ['MiniGPT4'] diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py new file mode 100644 index 0000000..d25d0b6 --- /dev/null +++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +import re +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.logging import MMLogger +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class MiniGPT4(BaseModel): + """The multi-modality model of MiniGPT-4. + + The implementation of `MiniGPT-4 `_. + Modified from https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py + + Args: + vision_encoder (dict): The config for vision encoder. + q_former_model (dict): The config for Qformer. + lang_encoder (dict): The config for language model. + tokenizer (dict): The config for tokenizer. + task (str): To define the task, which control the processing of text. + Defaults to 'caption'. + freeze_vit (bool): Freeze the training of ViT. Defaults to True. + freeze_q_former (bool): Freeze the training of Qformer. Defaults to + True. + num_query_token (int): Number of query tokens of Qformer. Defaults to + 32. + prompt_template (dict): Multi-language prompt template of the model. Defaults to dict([ ('en', '###Ask: {} ###Answer: '), + ('zh', '###问:{} ###答:')]) + raw_prompts (dict): Prompts for training. Defaults to dict(). + max_txt_len (int): Max token length while doing tokenization. Defaults + to 32. + end_sym (str): Ended symbol of the sequence. Defaults to '###'. + generation_cfg (dict): The config of text generation. Defaults to + dict(). + data_preprocessor (:obj:`BaseDataPreprocessor`): Used for + pre-processing data sampled by dataloader to the format accepted by + :meth:`forward`. Defaults to None. + init_cfg (dict): Initialization config dict. Defaults to None. 
+ """ # noqa + + def __init__(self, + vision_encoder: dict, + q_former_model: dict, + lang_encoder: dict, + tokenizer: dict, + task: str = 'caption', + freeze_vit: bool = True, + freeze_q_former: bool = True, + num_query_token: int = 32, + prompt_template: dict = dict([('en', + '###Ask: {} ###Answer: '), + ('zh', '###问:{} ###答:')]), + raw_prompts: dict = dict(), + max_txt_len: int = 32, + end_sym: str = '###', + generation_cfg: dict = dict(), + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.task = task + logger = MMLogger.get_current_instance() + + # build vision model + vision_encoder_weight = vision_encoder.pop('pretrained', None) + self.vision_encoder = MODELS.build(vision_encoder) + self.ln_vision = nn.LayerNorm(self.vision_encoder.embed_dims) + + if vision_encoder_weight is not None: + from mmengine.runner.checkpoint import load_checkpoint + load_checkpoint(self.vision_encoder, vision_encoder_weight) + self.vision_encoder.is_init = True + if freeze_vit: + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + else: + logger.warning('Please check `frozen_stages` in the dict of' + '`vision_encoder`. Also set it to be -1 if do not' + 'freeze ViT.') + + # build Qformer + q_former_model_weight = q_former_model.pop('pretrained', None) + self.q_former = MODELS.build(q_former_model) + self.q_former.cls = None + self.q_former.bert.embeddings.word_embeddings = None + self.q_former.bert.embeddings.position_embeddings = None + for layer in self.q_former.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + self.query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, self.q_former.config.hidden_size)) + self.query_tokens.data.normal_( + mean=0.0, std=self.q_former.config.initializer_range) + + if q_former_model_weight is not None: + from mmengine.runner.checkpoint import CheckpointLoader + state_dict = CheckpointLoader.load_checkpoint( + q_former_model_weight)['state_dict'] + self.load_state_dict(state_dict, strict=False) + # The ln_vision weights are also in the q-former checkpoint. 
+ setattr(self.ln_vision, 'is_init', True) + setattr(self.q_former, 'is_init', True) + + if freeze_q_former: + for name, param in self.q_former.named_parameters(): + param.requires_grad = False + self.q_former.eval() + self.query_tokens.requires_grad = False + + # build language model + self.llama_tokenizer = TOKENIZER.build(tokenizer) + self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token + + self.llama_model = MODELS.build(lang_encoder) + for name, param in self.llama_model.named_parameters(): + param.requires_grad = False + + # build linear projection layer + self.llama_proj = nn.Linear(self.q_former.config.hidden_size, + self.llama_model.config.hidden_size) + self.max_txt_len = max_txt_len + self.end_sym = end_sym + self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1] + + # set prompts + self.en_prompt_list, self.zh_prompt_list = [], [] + if raw_prompts.get('en') is not None: + en_filted_prompts = [ + raw_prompt for raw_prompt in raw_prompts['en'] + if '' in raw_prompt + ] + self.en_prompt_list = [ + prompt_template['en'].format(p) for p in en_filted_prompts + ] + if raw_prompts.get('zh') is not None: + zh_filted_prompts = [ + raw_prompt for raw_prompt in raw_prompts['zh'] + if '' in raw_prompt + ] + self.zh_prompt_list = [ + prompt_template['zh'].format(p) for p in zh_filted_prompts + ] + + # update generation configs + self.generation_cfg = dict( + max_new_tokens=300, + num_beams=1, + do_sample=True, + min_length=1, + top_p=0.9, + repetition_penalty=1.1, + length_penalty=1.0, + temperature=1.0) + self.generation_cfg.update(**generation_cfg) + + if hasattr(self, 'register_load_state_dict_post_hook'): + self.register_load_state_dict_post_hook(self._load_llama_proj_hook) + + def half(self): + self.llama_model = self.llama_model.half() + return self + + def encode_img(self, + images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """The function to encode the images.""" + device = images.device + x = self.vision_encoder(images)[0] + image_embeds = self.ln_vision(x).to(device) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.q_former.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + atts_llama = torch.ones( + inputs_llama.size()[:-1], dtype=torch.long).to(images.device) + return inputs_llama, atts_llama + + def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor, + prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + """The function to wrap the image and prompt. + + Make sure that len(prompt) == img_embeds.shape[0]. + + Args: + img_embeds (torch.Tensor): The embedding of the input images. + atts_img (torch.Tensor): Attention map of the image embeddings. + prompt (List[str]): The prompt of the batch data. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map. 
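# Example of how the raw prompts above become the final prompt list, assuming the image
# placeholder marker inside raw prompts is '<ImageHere>' (wrapped in '<Img>...</Img>'):
raw_prompts = dict(en=['<Img><ImageHere></Img> Describe this image in detail.'])
prompt_template = dict(en='###Ask: {} ###Answer: ')
en_prompt_list = [prompt_template['en'].format(p)
                  for p in raw_prompts['en'] if '<ImageHere>' in p]
# -> ['###Ask: <Img><ImageHere></Img> Describe this image in detail. ###Answer: ']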
+ """ + if len(prompt) > 0: + p_before_list, p_after_list = [], [] + for pro in prompt: + p_before, p_after = pro.split('') + p_before_list.append(p_before) + p_after_list.append(p_after) + p_before_tokens = self.llama_tokenizer( + p_before_list, + return_tensors='pt', + padding='longest', + add_special_tokens=False).to(img_embeds.device) + p_after_tokens = self.llama_tokenizer( + p_after_list, + return_tensors='pt', + padding='longest', + add_special_tokens=False).to(img_embeds.device) + p_before_embeds = self.llama_model.model.embed_tokens( + p_before_tokens.input_ids) + p_after_embeds = self.llama_model.model.embed_tokens( + p_after_tokens.input_ids) + wrapped_img_embeds = torch.cat( + [p_before_embeds, img_embeds, p_after_embeds], dim=1) + wrapped_atts_img = atts_img[:, :1].expand( + -1, wrapped_img_embeds.shape[1]) + return wrapped_img_embeds, wrapped_atts_img + else: + return img_embeds, atts_img + + def loss(self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None) -> dict: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + img_embeds, atts_img = self.encode_img(images) + + self.llama_tokenizer.padding_side = 'right' + + prompts, texts = [], [] + for t in data_samples: + chat_content = t.chat_content + split_mark = '###Answer: ' if t.lang == 'en' else '###答:' + prompt, text = chat_content.split(split_mark) + prompt += split_mark + text += self.end_sym + prompts.append(prompt) + texts.append(text) + + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts) + + to_regress_tokens = self.llama_tokenizer( + texts, + return_tensors='pt', + padding='longest', + truncation=True, + max_length=self.max_txt_len, + add_special_tokens=False).to(images.device) + + targets = to_regress_tokens.input_ids.masked_fill( + to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, + -100) + + empty_targets = ( + torch.ones([atts_img.shape[0], atts_img.shape[1] + 1], + dtype=torch.long).to(images.device).fill_( + -100) # plus one for bos + ) + targets = torch.cat([empty_targets, targets], dim=1) + + batch_size = img_embeds.shape[0] + bos = torch.ones([batch_size, 1], + dtype=to_regress_tokens.input_ids.dtype, + device=to_regress_tokens.input_ids.device + ) * self.llama_tokenizer.bos_token_id + bos_embeds = self.llama_model.model.embed_tokens(bos) + atts_bos = atts_img[:, :1] + + to_regress_embeds = self.llama_model.model.embed_tokens( + to_regress_tokens.input_ids) + inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], + dim=1) + attention_mask = torch.cat( + [atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1) + + outputs = self.llama_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + return dict(loss=loss) + + def predict( + self, + images: torch.Tensor, + data_samples: Optional[List[DataSample]] = None + ) -> List[DataSample]: + + with torch.no_grad(): + img_embeds, atts_img = self.encode_img(images) + + prompts = [ + random.choice(self.zh_prompt_list) if hasattr(t, 'lang') + and t.lang == 'zh' else random.choice(self.en_prompt_list) + for t in data_samples + ] + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts) + + batch_size = img_embeds.shape[0] + bos = torch.ones( + [batch_size, 1], dtype=torch.long, + 
device=img_embeds.device) * self.llama_tokenizer.bos_token_id + bos_embeds = self.llama_model.model.embed_tokens(bos) + inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1) + + outputs = self.llama_model.generate( + inputs_embeds=inputs_embeds, + eos_token_id=self.end_token_id, + **self.generation_cfg) + + return self.post_process(outputs, data_samples) + + def post_process( + self, outputs: torch.Tensor, + data_samples: Optional[List[DataSample]]) -> List[DataSample]: + """Perform post process for outputs for different task. + + Args: + outputs (torch.Tensor): The generated outputs. + data_samples (List[DataSample], optional): The annotation + data of every samples. + + Returns: + List[DataSample]: Return list of data samples. + """ + outputs = self.llama_tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + if self.task == 'caption': + output = output.split('###')[0] + data_sample.pred_caption = output + else: + # raw output + data_sample.pred_output = output + return data_samples + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'loss': + return self.loss(images, data_samples) + elif mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + @staticmethod + def _load_llama_proj_hook(module, incompatible_keys): + """Avoid warning missing keys except LLaMA projection keys.""" + proj_patterns = [ + 'vision_encoder.*', + 'ln_vision.*', + 'q_former.*', + 'query_tokens', + 'llama_model.*', + ] + for key in list(incompatible_keys.missing_keys): + if any(re.match(pattern, key) for pattern in proj_patterns): + incompatible_keys.missing_keys.remove(key) diff --git a/mmpretrain/models/multimodal/ofa/__init__.py b/mmpretrain/models/multimodal/ofa/__init__.py new file mode 100644 index 0000000..bcb3f45 --- /dev/null +++ b/mmpretrain/models/multimodal/ofa/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ofa import OFA +from .ofa_modules import OFADecoder, OFAEncoder, OFAEncoderDecoder + +__all__ = ['OFAEncoderDecoder', 'OFA', 'OFAEncoder', 'OFADecoder'] diff --git a/mmpretrain/models/multimodal/ofa/ofa.py b/mmpretrain/models/multimodal/ofa/ofa.py new file mode 100644 index 0000000..e15787a --- /dev/null +++ b/mmpretrain/models/multimodal/ofa/ofa.py @@ -0,0 +1,320 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
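# Label layout used in MiniGPT4.loss() above: the BOS slot and every image/prompt position
# are filled with -100 so only answer tokens contribute to the loss (toy ids):
import torch
n_wrapped = 5                                         # length of the wrapped image+prompt embeds
answer_ids = torch.tensor([[318, 257, 3797, 835]])    # made-up answer token ids
empty_targets = torch.full((1, n_wrapped + 1), -100)  # +1 accounts for the BOS token
targets = torch.cat([empty_targets, answer_ids], dim=1)   # (1, 1 + n_wrapped + T_answer)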
+import string +from collections import defaultdict +from functools import partial +from typing import Optional, Union + +import mmengine +import torch +from mmengine.model import BaseModel + +from mmpretrain.datasets import CleanCaption +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .ofa_modules import OFAEncoderDecoder + + +class TreeNode(): + + def __init__(self): + self.child = defaultdict(TreeNode) + + +class Trie: + + def __init__(self, eos): + self.root = TreeNode() + self.eos = eos + + def insert(self, word): + cur = self.root + for c in word: + cur = cur.child[c] + + def get_next_layer(self, word): + cur = self.root + for c in word: + cur = cur.child.get(c) + if cur is None: + return [self.eos] + return list(cur.child.keys()) + + +def apply_constraint( + input_ids: torch.Tensor, + logits: torch.Tensor, + decoder_prompts: Optional[list], + num_beams: int, + constraint_trie: Trie = None, +): + if decoder_prompts is None and constraint_trie is None: + return logits + + mask = logits.new_zeros(logits[:, -1, :].size(), dtype=torch.bool) + input_ids = input_ids.view(-1, num_beams, input_ids.shape[-1]) + for batch_id, beam_sent in enumerate(input_ids): + for beam_id, sent in enumerate(beam_sent): + if decoder_prompts is None: + prompt_len = 0 + else: + prompt_len = len(decoder_prompts[batch_id]) + + if sent.size(0) - 1 < prompt_len: + allowed_tokens = [decoder_prompts[batch_id][sent.size(0) - 1]] + mask[batch_id * num_beams + beam_id, allowed_tokens] = True + elif constraint_trie is not None: + answer_tokens = [0] + sent[prompt_len + 1:].tolist() + allowed_tokens = constraint_trie.get_next_layer(answer_tokens) + mask[batch_id * num_beams + beam_id, allowed_tokens] = True + else: + mask[batch_id * num_beams + beam_id, :] = True + logits[:, -1, :].masked_fill_(~mask, float('-inf')) + return logits + + +@MODELS.register_module() +class OFA(BaseModel): + """The OFA model for multiple tasks. + + Args: + encoder_cfg (dict): The config of the encoder, accept the keyword + arguments of :class:`OFAEncoder`. + decoder_cfg (dict): The config of the decoder, accept the keyword + arguments of :class:`OFADecoder`. + vocab_size (int): The size of the vocabulary. + embedding_dim (int): The embedding dimensions of both the encoder + and the decoder. + tokenizer (dict | PreTrainedTokenizer): The tokenizer to encode + the text. + task (str): The task name, supported tasks are "caption", "vqa" and + "refcoco". + prompt (str, optional): The prompt template for the following tasks, + If None, use default prompt: + + - **caption**: ' what does the image describe?' + - **refcoco**: ' which region does the text " {} " describe?' + + Defaults to None + ans2label (str | Sequence | None): The answer to label mapping for + the vqa task. If a string, it should be a pickle or json file. + The sequence constrains the output answers. Defaults to None, + which means no constraint. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of :class:`~transformers.GenerationConfig`. + Defaults to an empty dict. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. See :class: + `MultiModalDataPreprocessor` for more details. Defaults to None. + init_cfg (dict, optional): The initialization config. Defaults to None. 
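# Toy walk-through of the answer Trie used below for constrained VQA decoding
# (made-up token ids; 2 plays the role of the eos id):
trie = Trie(eos=2)
trie.insert([0, 11, 12])                      # e.g. tokenized ' yes'
trie.insert([0, 11, 13])                      # e.g. tokenized ' no'
assert trie.get_next_layer([0, 11]) == [12, 13]
assert trie.get_next_layer([0, 99]) == [2]    # unknown prefix -> eos only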
+ """ + support_tasks = {'caption', 'vqa', 'refcoco'} + + def __init__( + self, + encoder_cfg, + decoder_cfg, + vocab_size, + embedding_dim, + tokenizer, + task, + prompt=None, + ans2label: Union[dict, str, None] = None, + generation_cfg=dict(), + data_preprocessor: Optional[dict] = None, + init_cfg=None, + ): + if data_preprocessor is None: + data_preprocessor = {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if isinstance(tokenizer, dict): + self.tokenizer = TOKENIZER.build(tokenizer) + else: + self.tokenizer = tokenizer + + if task not in self.support_tasks: + raise ValueError(f'Unsupported task {task}, please select ' + f'the task from {self.support_tasks}.') + + self.prompt = prompt + self.task = task + + if isinstance(ans2label, str): + self.ans2label = mmengine.load(ans2label) + else: + self.ans2label = ans2label + + if self.task == 'vqa' and self.ans2label is not None: + self.constraint_trie = Trie(eos=self.tokenizer.eos_token_id) + answers = [f' {answer}' for answer in self.ans2label] + answer_tokens = self.tokenizer(answers, padding=False) + for answer_token in answer_tokens['input_ids']: + self.constraint_trie.insert(answer_token) + else: + self.constraint_trie = None + + generation_cfg = { + 'num_beams': 5, + 'max_new_tokens': 20, + 'no_repeat_ngram_size': 3, + **generation_cfg, + } + self.model = OFAEncoderDecoder( + encoder_cfg=encoder_cfg, + decoder_cfg=decoder_cfg, + padding_idx=self.tokenizer.pad_token_id, + vocab_size=vocab_size, + embedding_dim=embedding_dim, + generation_cfg=generation_cfg, + ) + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def predict( + self, + images, + data_samples=None, + post_process=True, + **generation_config, + ): + text_tokens = self.preprocess_text(data_samples, images.size(0), + images.device) + + if 'images_mask' in data_samples[0]: + images_mask = torch.tensor([ + sample.get('images_mask') for sample in data_samples + ]).bool().to(images.device) + else: + images_mask = None + + num_beams = generation_config.get( + 'num_beams', getattr(self.model.generation_config, 'num_beams')) + decoder_prompts = self.get_decoder_prompts(data_samples) + constrain_fn = partial( + apply_constraint, + constraint_trie=self.constraint_trie, + decoder_prompts=decoder_prompts, + num_beams=num_beams, + ) + + outputs = self.model.generate( + input_ids=text_tokens, + images=images, + images_mask=images_mask, + constrain_fn=constrain_fn, + **generation_config, + ) + + if decoder_prompts is not None: + # Remove the prefix decoder prompt. 
+ for prompt_ids, token in zip(decoder_prompts, outputs): + token[1:len(prompt_ids) + 1] = self.tokenizer.pad_token_id + + if post_process: + return self.post_process(outputs, data_samples) + else: + return outputs + + def get_decoder_prompts(self, data_samples): + decoder_prompts = [] + if 'decoder_prompt' not in data_samples[0]: + return None + for sample in data_samples: + prompt = ' ' + sample.get('decoder_prompt') + prompt_ids = self.tokenizer(prompt, add_special_tokens=False) + prompt_ids = prompt_ids['input_ids'] + decoder_prompts.append(prompt_ids) + return decoder_prompts + + def preprocess_text(self, data_samples, batch_size, device): + if self.task == 'caption': + prompt = self.prompt or ' what does the image describe?' + prompts = [prompt] * batch_size + prompts = self.tokenizer(prompts, return_tensors='pt') + return prompts.input_ids.to(device) + elif self.task == 'vqa': + prompts = [] + for sample in data_samples: + assert 'question' in sample + prompt = ' ' + sample.get('question') + prompts.append(prompt) + prompts = self.tokenizer( + prompts, return_tensors='pt', padding=True) + return prompts.input_ids.to(device) + elif self.task == 'refcoco': + prompt_template = self.prompt or \ + ' which region does the text " {} " describe?' + prompts = [] + for sample in data_samples: + assert 'text' in sample + prompt = prompt_template.format(sample.get('text')) + prompts.append(prompt) + prompts = self.tokenizer( + prompts, return_tensors='pt', padding=True) + return prompts.input_ids.to(device) + + def post_process(self, outputs, data_samples): + + out_data_samples = [] + if data_samples is None: + data_samples = [None] * outputs.size(0) + + for data_sample, token in zip(data_samples, outputs): + if data_sample is None: + data_sample = DataSample() + + if self.task == 'caption': + text = self.tokenizer.decode(token, skip_special_tokens=True) + text = CleanCaption( + lowercase=False, + remove_chars=string.punctuation).clean(text) + data_sample.pred_caption = text + elif self.task == 'vqa': + text = self.tokenizer.decode(token, skip_special_tokens=True) + data_sample.pred_answer = text.strip() + elif self.task == 'refcoco': + bbox = token[1:5] - self.tokenizer.bin_offset + # During training, the bbox is normalized by 512. It's related + # to the `max_image_size` config in the official repo. + bbox = bbox / self.tokenizer.num_bins * 512 + scale_factor = data_sample.get('scale_factor', (1, 1)) + bbox[0::2] /= scale_factor[0] + bbox[1::2] /= scale_factor[1] + data_sample.pred_bboxes = bbox.unsqueeze(0) + if 'gt_bboxes' in data_sample: + gt_bboxes = bbox.new_tensor(data_sample.gt_bboxes) + gt_bboxes[:, 0::2] /= scale_factor[0] + gt_bboxes[:, 1::2] /= scale_factor[1] + data_sample.gt_bboxes = gt_bboxes + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmpretrain/models/multimodal/ofa/ofa_modules.py b/mmpretrain/models/multimodal/ofa/ofa_modules.py new file mode 100644 index 0000000..ef5c853 --- /dev/null +++ b/mmpretrain/models/multimodal/ofa/ofa_modules.py @@ -0,0 +1,1613 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
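# Toy decode of the RefCOCO branch in OFA.post_process above; bin_offset and num_bins are
# hypothetical here, the real values come from the OFA tokenizer:
import torch
bin_offset, num_bins = 58457, 1000
token = torch.tensor([0, bin_offset + 100, bin_offset + 200,
                      bin_offset + 300, bin_offset + 400])
bbox = (token[1:5] - bin_offset).float() / num_bins * 512   # 512 == max_image_size
# -> tensor([ 51.2000, 102.4000, 153.6000, 204.8000]), then rescaled by scale_factor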
+import math +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.utils import digit_version +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, ModelOutput, Seq2SeqLMOutput) +from transformers.modeling_utils import (GenerationConfig, GenerationMixin, + PretrainedConfig) + +from mmpretrain.registry import MODELS +from ...backbones.resnet import Bottleneck, ResNet + +if digit_version(torch.__version__) >= digit_version('1.10.0'): + torch_meshgrid = partial(torch.meshgrid, indexing='ij') +else: + torch_meshgrid = torch.meshgrid + + +def make_token_bucket_position(bucket_size, max_position=1024): + context_pos = torch.arange(max_position, dtype=torch.long)[:, None] + memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] + relative_pos = context_pos - memory_pos + sign = torch.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), + mid - 1, torch.abs(relative_pos)) + log_pos = torch.ceil( + torch.log(abs_pos / mid) / math.log( + (max_position - 1) / mid) * (mid - 1)) + mid + log_pos = log_pos.int() + bucket_pos = torch.where(abs_pos.le(mid), relative_pos, + log_pos * sign).long() + return bucket_pos + bucket_size - 1 + + +def make_image_bucket_position(bucket_size, num_relative_distance): + coords_h = torch.arange(bucket_size) + coords_w = torch.arange(bucket_size) + # (2, h, w) + coords = torch.stack(torch_meshgrid([coords_h, coords_w])) + # (2, h*w) + coords_flatten = torch.flatten(coords, 1) + # (2, h*w, h*w) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + # (h*w, h*w, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += bucket_size - 1 + relative_coords[:, :, 0] *= 2 * bucket_size - 1 + relative_position_index = torch.zeros( + size=(bucket_size * bucket_size + 1, ) * 2, + dtype=relative_coords.dtype) + # (h*w, h*w) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return relative_position_index + + +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + past_key_values_length: int = 0): + """Make causal mask used for uni-directional self-attention.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float('-inf')) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + """Expands attention_mask from ``[B, L_s]`` to ``[B, 1, L_t, L_s]``. + + Where ``B`` is batch_size, `L_s`` is the source sequence length, and + ``L_t`` is the target sequence length. 
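# What _make_causal_mask above produces for a 4-token target with no cached keys:
import torch
mask = _make_causal_mask(torch.Size([1, 4]), torch.float32)
# mask[0, 0]:
#   [[0., -inf, -inf, -inf],
#    [0.,   0., -inf, -inf],
#    [0.,   0.,   0., -inf],
#    [0.,   0.,   0.,   0.]]
assert mask.shape == (1, 1, 4, 4)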
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + return expanded_mask.masked_fill(expanded_mask.bool(), + torch.finfo(dtype).min) + + +class MultiheadAttention(BaseModule): + """Multi-head Attention Module for OFA. + + Args: + embedding_dim (int): The embedding dimension of query. + num_heads (int): Parallel attention heads. + kdim (int, optional): The embedding dimension of key. + Defaults to None, which means the same as the `embedding_dim`. + vdim (int, optional): The embedding dimension of value. + Defaults to None, which means the same as the `embedding_dim`. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + scale_factor (float): The scale of qk will be + ``(head_dim * scale_factor) ** -0.5``. Defaults to 1. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embedding_dim, + num_heads, + kdim=None, + vdim=None, + attn_drop=0., + scale_factor=1., + qkv_bias=True, + proj_bias=True, + scale_heads=False, + init_cfg=None): + super(MultiheadAttention, self).__init__(init_cfg=init_cfg) + + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.kdim = kdim or embedding_dim + self.vdim = vdim or embedding_dim + + self.head_dim = embedding_dim // num_heads + self.scale = (self.head_dim * scale_factor)**-0.5 + + self.q_proj = nn.Linear(embedding_dim, embedding_dim, bias=qkv_bias) + self.k_proj = nn.Linear(self.kdim, embedding_dim, bias=qkv_bias) + self.v_proj = nn.Linear(self.vdim, embedding_dim, bias=qkv_bias) + self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=proj_bias) + + self.attn_drop = nn.Dropout(p=attn_drop) + + if scale_heads: + self.c_attn = nn.Parameter(torch.ones(num_heads)) + else: + self.c_attn = None + + def forward( + self, + query, + key_value=None, + attn_mask=None, + attn_bias=None, + past_key_value=None, + output_attentions=False, + ): + B, _, C = query.shape + assert C == self.head_dim * self.num_heads + + is_cross_attention = key_value is not None + if key_value is None: + key_value = query + + # (B, L, C) -> (B, num_heads, L, head_dims) + q = self.q_proj(query).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + + if is_cross_attention and past_key_value is not None: + # Reuse key and value in cross_attentions + k, v = past_key_value + else: + k = self.k_proj(key_value).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + v = self.v_proj(key_value).reshape(B, -1, self.num_heads, + self.head_dim).transpose(1, 2) + if past_key_value is not None: + past_key, past_value = past_key_value + k = torch.cat([past_key, k], dim=2) + v = torch.cat([past_value, v], dim=2) + + past_key_value = (k, v) + + attn_weights = q @ k.transpose(-2, -1) * self.scale + + if attn_bias is not None: + src_len = k.size(2) + attn_weights[:, :, -src_len:] += attn_bias[:, :, -src_len:] + + if attn_mask is not None: + attn_weights += attn_mask + attn_weights = torch.softmax(attn_weights, dim=-1) + attn = self.attn_drop(attn_weights) @ v + + if self.c_attn is not None: + attn = torch.einsum('bhlc,h->bhlc', attn, self.c_attn) + + # (B, num_heads, L, head_dims) -> (B, L, C) + attn = attn.transpose(1, 2).reshape(B, -1, 
self.embedding_dim) + attn = self.out_proj(attn) + + if output_attentions: + return attn, attn_weights, past_key_value + else: + return attn, None, past_key_value + + +@MODELS.register_module(force=True) +class OFAResNet(ResNet): + """ResNet module for OFA. + + The ResNet in OFA has only three stages. + """ + arch_settings = { + 50: (Bottleneck, (3, 4, 6)), + 101: (Bottleneck, (3, 4, 23)), + 152: (Bottleneck, (3, 8, 36)), + } + + def __init__(self, depth, *args, **kwargs): + super().__init__( + depth=depth, + *args, + num_stages=3, + out_indices=(2, ), + dilations=(1, 1, 1), + strides=(1, 2, 2), + **kwargs) + + +@dataclass +class OFAEncoderOutput(ModelOutput): + """OFA encoder outputs. + + Args: + last_hidden_state (torch.tensor): The hidden-states of the output at + the last layer of the model. The shape is (B, L, C). + hidden_states (Tuple[torch.tensor]): The initial embedding and the + output of each layer. The shape of every item is (B, L, C). + attentions (Tuple[torch.tensor]): The attention weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. The shape of every item is + (B, num_heads, L, L). + position_embedding (torch.tensor): The positional embeddings of the + inputs. The shape is (B, L, C). + """ + + last_hidden_state: torch.FloatTensor = None + padding_mask: torch.Tensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + position_embedding: Optional[torch.FloatTensor] = None + + +class OFAEncoderLayer(nn.Module): + """OFAEncoder layer block.""" + + def __init__(self, + embedding_dim, + num_heads, + dropout_rate=0., + drop_path_rate=0., + attn_drop=0., + act_drop=0., + scale_factor=2., + mlp_ratio=4., + scale_heads=True, + normformer=True, + pre_norm=True, + act_cfg=dict(type='GELU')): + super().__init__() + self.embedding_dim = embedding_dim + self.pre_norm = pre_norm + + self.attn = MultiheadAttention( + embedding_dim=embedding_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + mid_channels = int(embedding_dim * mlp_ratio) + self.fc1 = nn.Linear(embedding_dim, mid_channels) + self.fc2 = nn.Linear(mid_channels, embedding_dim) + self.act = MODELS.build(act_cfg) + self.act_drop = nn.Dropout( + act_drop) if act_drop > 0. else nn.Identity() + + # LayerNorm between attention block and ffn block. + self.attn_ln = nn.LayerNorm(embedding_dim) + self.ffn_ln = nn.LayerNorm(embedding_dim) + + # Extra LayerNorm + self.normformer = normformer + if self.normformer: + self.attn_mid_ln = nn.LayerNorm(embedding_dim) + self.ffn_mid_ln = nn.LayerNorm(mid_channels) + + self.dropout = nn.Dropout(dropout_rate) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, + x, + attention_mask=None, + attn_bias=None, + output_attentions=False): + """Forward the encoder layer. + + Args: + x (torch.tensor): The input to the layer of shape ``(B, L, C)``. + attention_mask (torch.Tensor, optional): The attention mask of size + ``(B, 1, L, L)``, where padding elements are indicated by very + large negative values. Defaults to None. + attn_bias (torch.tensor, optional): The bias for positional + information. Defaults to None. + output_attentions (bool): Whether to return the attentions tensors + of the attention layer. + + Returns: + List[torch.tensor]: The first element is the encoded output of + shape ``(B, L, C)``. 
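# Incremental decoding sketch for the MultiheadAttention above: cached keys/values grow
# by one position per step via past_key_value (tiny made-up sizes):
import torch
attn = MultiheadAttention(embedding_dim=16, num_heads=4)
out1, _, cache = attn(torch.randn(2, 1, 16))                       # first token
out2, _, cache = attn(torch.randn(2, 1, 16), past_key_value=cache)  # second token
assert cache[0].shape == (2, 4, 2, 4)        # (B, num_heads, L=2, head_dim)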
And the second element is the output + attentions if ``output_attentions=True``. + """ + residual = x + + # Attention block + if self.pre_norm: + x = self.attn_ln(x) + x, attn_weights, _ = self.attn( + query=x, + attn_mask=attention_mask, + attn_bias=attn_bias, + output_attentions=output_attentions) + if self.normformer: + x = self.attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.attn_ln(x) + + residual = x + + # FFN block + if self.pre_norm: + x = self.ffn_ln(x) + x = self.act(self.fc1(x)) + x = self.act_drop(x) + if self.normformer: + x = self.ffn_mid_ln(x) + x = self.fc2(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.ffn_ln(x) + + if output_attentions: + return [x, attn_weights] + else: + return [x] + + +class OFADecoderLayer(nn.Module): + """OFADecoder layer block.""" + + def __init__(self, + embedding_dim, + num_heads, + dropout_rate=0., + drop_path_rate=0., + attn_drop=0., + act_drop=0., + scale_factor=2., + mlp_ratio=4., + encoder_embed_dim=None, + scale_heads=True, + normformer=True, + pre_norm=True, + act_cfg=dict(type='GELU')): + super().__init__() + self.embedding_dim = embedding_dim + self.pre_norm = pre_norm + + self.self_attn = MultiheadAttention( + embedding_dim=embedding_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + self.cross_attn = MultiheadAttention( + embedding_dim=embedding_dim, + kdim=encoder_embed_dim, + vdim=encoder_embed_dim, + num_heads=num_heads, + attn_drop=attn_drop, + scale_factor=scale_factor, + scale_heads=scale_heads, + ) + + mid_channels = int(embedding_dim * mlp_ratio) + self.fc1 = nn.Linear(embedding_dim, mid_channels) + self.fc2 = nn.Linear(mid_channels, embedding_dim) + self.act = MODELS.build(act_cfg) + self.act_drop = nn.Dropout( + act_drop) if act_drop > 0. else nn.Identity() + + # LayerNorm between attention block and ffn block. + self.self_attn_ln = nn.LayerNorm(embedding_dim) + self.cross_attn_ln = nn.LayerNorm(embedding_dim) + self.ffn_ln = nn.LayerNorm(embedding_dim) + + # Extra LayerNorm + self.normformer = normformer + if self.normformer: + self.self_attn_mid_ln = nn.LayerNorm(embedding_dim) + self.cross_attn_mid_ln = nn.LayerNorm(embedding_dim) + self.ffn_mid_ln = nn.LayerNorm(mid_channels) + + self.dropout = nn.Dropout(dropout_rate) + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward( + self, + x, + attention_mask=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[List[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + self_attn_bias: Optional[torch.Tensor] = None, + cross_attn_bias: Optional[torch.Tensor] = None, + ): + """Forward the decoder layer. + + Args: + x (torch.tensor): The input to the layer of shape ``(B, L, C)``. + attention_mask (torch.Tensor, optional): The attention mask of size + ``(B, 1, L, L)``, where padding elements are indicated by very + large negative values. Defaults to None. + encoder_hidden_states (torch.Tensor, optional): The cross attention + input to the layer of size ``(B, L, C)``. Defaults to None. + encoder_attention_mask (torch.Tensor, optional): The cross + attention mask where padding elements are indicated by very + large negative values. Defaults to None. 
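# The three-stage OFAResNet above ends at 1024 channels with overall stride 16, which is
# why the encoder's image_proj is Linear(1024, embedding_dim); a quick shape check:
import torch
backbone = OFAResNet(depth=50)
feat = backbone(torch.randn(1, 3, 256, 256))[0]
assert feat.shape == (1, 1024, 16, 16)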
+ past_key_value (Tuple[torch.tensor], optional): The cached past key + and value projection states. Defaults to none. + output_attentions (bool): whether to return the attentions tensors + of all attention layers. Defaults to False. + use_cache (bool, optional): Whether to use cache. + Defaults to False. + self_attn_bias (torch.Tensor, optional): The self attention bias + for positional information. Defaults to None. + cross_attn_bias (torch.Tensor, optional): The cross attention bias + for positional information. Defaults to None. + + Returns: + List[torch.tensor]: The first element is the encoded output of + shape ``(B, L, C)``. The following two elements can be the output + self-attentions and cross-attentions if ``output_attentions=True``. + The following one element can be the cached past key and value + projection states. + """ + residual = x + + if past_key_value is not None: + self_past_key_value = past_key_value[:2] + cross_past_key_value = past_key_value[2:] + else: + self_past_key_value, cross_past_key_value = None, None + + # Self-Attention block + if self.pre_norm: + x = self.self_attn_ln(x) + x, self_attn_weights, present_key_value = self.self_attn( + query=x, + past_key_value=self_past_key_value, + attn_mask=attention_mask, + output_attentions=output_attentions, + attn_bias=self_attn_bias, + ) + if self.normformer: + x = self.self_attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.self_attn_ln(x) + + # Cross-Attention block + if encoder_hidden_states is not None: + residual = x + if self.pre_norm: + x = self.cross_attn_ln(x) + x, cross_attn_weights, cross_key_value = self.cross_attn.forward( + query=x, + key_value=encoder_hidden_states, + attn_mask=encoder_attention_mask, + past_key_value=cross_past_key_value, + output_attentions=output_attentions, + attn_bias=cross_attn_bias) + if self.normformer: + x = self.cross_attn_mid_ln(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.cross_attn_ln(x) + + present_key_value = present_key_value + cross_key_value + + residual = x + + # FFN block + if self.pre_norm: + x = self.ffn_ln(x) + x = self.act(self.fc1(x)) + x = self.act_drop(x) + if self.normformer: + x = self.ffn_mid_ln(x) + x = self.fc2(x) + x = self.dropout(x) + x = residual + self.drop_path(x) + if not self.pre_norm: + x = self.ffn_ln(x) + + outputs = [x] + + if output_attentions: + outputs.extend([self_attn_weights, cross_attn_weights]) + + if use_cache: + outputs.append(present_key_value) + + return outputs + + +class OFAEncoder(BaseModule): + """The encoder module of OFA. + + Args: + embed_tokens (nn.Embedding): The embedding module to embed the + input tokens. + embed_images (dict | nn.Module): The module to embed the input + images into features. The output number of channels should + be 1024. + num_layers (int): The number of encoder layers. Defaults to 6. + num_heads (int): The number of heads of attention. Defaults to 12. + dropout_rate (float): The prob of dropout for embedding and + transformer layers. Defaults to 0. + drop_path_rate (float): The prob of droppath for transformer layers. + Defaults to 0. + max_source_positions (int): The maximum length of the input tokens. + Defaults to 1024. + token_bucket_size (int): The token bucket size, it's used as the + maximum relative position index in relative position embedding + of input tokens. Defaults to 256. 
+ image_bucket_size (int): The image bucket size, it's used to generate + the image relative position embedding table. It should be larger + than the shape of image feature map. Defaults to 42. + attn_scale_factor (float): The scale factor to calculate qk scale in + attentions. Defaults to 2. + scale_embedding (bool): Whether to scale the embeddings by the square + root of the dimension. Defaults to False. + add_embedding_ln (bool): Whether to add an extra layer norm for token + embeddings. Defaults to True. + add_image_embedding_ln (bool): Whether to add an extra layer norm for + image embeddings. Defaults to True. + pre_norm (bool): Whether to do layer norm before attention and ffn + blocks in transformer layers. Defaults to True. + entangle_position_embedding (bool): Whether to add the position + embedding on the embeddings directly. Defaults to False. + init_cfg (dict, optional): The initialization config. Defaults to None. + """ + + def __init__( + self, + embed_tokens, + embed_images: dict, + num_layers=6, + num_heads=12, + dropout_rate=0., + drop_path_rate=0., + max_source_positions=1024, + token_bucket_size=256, + image_bucket_size=42, + attn_scale_factor=2., + scale_embedding=False, + add_embedding_ln=True, + add_type_embed=True, + add_image_embedding_ln=True, + pre_norm=True, + entangle_position_embedding=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + + self.num_layers = num_layers + embedding_dim = embed_tokens.embedding_dim + self.embedding_dim = embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = max_source_positions + self.num_heads = num_heads + + # Build embedding process components + self.embed_tokens = embed_tokens + self.embedding_scale = math.sqrt( + embedding_dim) if scale_embedding else 1.0 + + if not isinstance(embed_images, nn.Module): + self.embed_images = MODELS.build(embed_images) + else: + self.embed_images = embed_images + self.image_proj = nn.Linear(1024, embedding_dim) + + if add_embedding_ln: + self.embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.embedding_ln = None + + if add_type_embed: + self.embed_type = nn.Embedding(2, embedding_dim) + else: + self.embed_type = None + + if add_image_embedding_ln: + self.image_embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.image_embedding_ln = None + + self.entangle_position_embedding = entangle_position_embedding + + # Build position embedding + self.embed_positions = nn.Embedding(self.max_source_positions + 2, + embedding_dim) + self.pos_ln = nn.LayerNorm(embedding_dim) + self.embed_image_positions = nn.Embedding(image_bucket_size**2 + 1, + embedding_dim) + self.image_pos_ln = nn.LayerNorm(embedding_dim) + + self.pos_scaling = float(embedding_dim / num_heads * + attn_scale_factor)**-0.5 + self.pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + + self.dropout = nn.Dropout( + dropout_rate) if dropout_rate > 0. 
else nn.Identity() + + # Register token relative position embedding table + self.token_bucket_size = token_bucket_size + token_num_rel_dis = 2 * token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(token_bucket_size, + self.max_source_positions) + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.token_rel_pos_table_list = nn.ModuleList() + + # Register image relative position embedding table + self.image_bucket_size = image_bucket_size + image_num_rel_dis = (2 * image_bucket_size - + 1) * (2 * image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position(image_bucket_size, + image_num_rel_dis) + self.register_buffer('image_rp_bucket', image_rp_bucket) + self.image_rel_pos_table_list = nn.ModuleList() + + # Build encoder layers + self.layers = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + for index in range(self.num_layers): + layer = OFAEncoderLayer( + embedding_dim=embedding_dim, + num_heads=num_heads, + dropout_rate=dropout_rate, + drop_path_rate=dpr[index], + scale_factor=attn_scale_factor, + pre_norm=pre_norm, + ) + self.layers.append(layer) + token_pos_table = nn.Embedding(token_num_rel_dis, self.num_heads) + image_pos_table = nn.Embedding(image_num_rel_dis, self.num_heads) + nn.init.constant_(token_pos_table.weight, 0.) + nn.init.constant_(image_pos_table.weight, 0.) + self.token_rel_pos_table_list.append(token_pos_table) + self.image_rel_pos_table_list.append(image_pos_table) + + if pre_norm: + self.final_ln = nn.LayerNorm(embedding_dim) + else: + self.final_ln = None + + main_input_name = 'input_ids' + + def forward(self, + input_ids, + images, + images_mask, + output_attentions=False, + output_hidden_states=False, + sample_patch_num=None): + padding_mask = input_ids.eq(self.padding_idx) + has_pads = padding_mask.any() + token_embedding = self.embed_tokens(input_ids) + token_embedding = self.embedding_scale * token_embedding + + # Embed the token position + src_pos_idx = torch.arange(input_ids.size(-1), device=input_ids.device) + src_pos_idx = src_pos_idx.expand(*input_ids.shape).contiguous() + pos_embedding = self.embed_positions(src_pos_idx) + + # Embed the input tokens + x = self.process_embedding( + embedding=token_embedding, + type_tokens=input_ids.new_zeros(token_embedding.shape[:2]), + pos_embedding=pos_embedding, + embedding_ln=self.embedding_ln, + ) + pos_embedding = self.pos_ln(pos_embedding) + + # Embed the input images + if images is not None: + (image_tokens, image_padding_mask, image_position_ids, + image_pos_embedding) = self.get_image_tokens( + images, + sample_patch_num, + images_mask, + ) + image_embedding = self.image_proj(image_tokens) + + image_x = self.process_embedding( + embedding=image_embedding, + type_tokens=input_ids.new_ones(image_embedding.shape[:2]), + pos_embedding=image_pos_embedding, + embedding_ln=self.image_embedding_ln, + ) + image_pos_embedding = self.image_pos_ln(image_pos_embedding) + + x = torch.cat([image_x, x], dim=1) + padding_mask = torch.cat([image_padding_mask, padding_mask], dim=1) + pos_embedding = torch.cat([image_pos_embedding, pos_embedding], + dim=1) + + # account for padding while computing the representation + if has_pads: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + # Decoupled position embedding + B, L = pos_embedding.shape[:2] + pos_q = self.pos_q_linear(pos_embedding).view( + B, L, self.num_heads, -1).transpose(1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(pos_embedding).view(B, L, self.num_heads, + -1).transpose(1, 
2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + all_hidden_states = [] if output_hidden_states else None + all_attentions = [] if output_attentions else None + + for idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states.append(x) + + self_attn_bias = abs_pos_bias.clone() + # Add decoupled position embedding for input tokens. + token_len = input_ids.size(1) + rel_pos_bias = self.get_rel_pos_bias(input_ids, idx) + self_attn_bias[:, :, -token_len:, -token_len:] += rel_pos_bias + + # Add decoupled position embedding for images + if images is not None: + token_len = image_tokens.size(1) + rel_pos_bias = self.get_image_rel_pos_bias( + image_position_ids, idx) + self_attn_bias[:, :, :token_len, :token_len] += rel_pos_bias + + if has_pads: + attention_mask = _expand_mask(padding_mask, dtype=x.dtype) + else: + attention_mask = None + + out = layer( + x, + attention_mask=attention_mask, + attn_bias=self_attn_bias, + output_attentions=output_attentions) + x = out[0] + + if output_attentions: + all_attentions.append(out[1]) + + if output_hidden_states: + all_hidden_states.append(x) + + if self.final_ln is not None: + x = self.final_ln(x) + + return OFAEncoderOutput( + last_hidden_state=x, # (B, L, C) + padding_mask=padding_mask, # (B, L) + position_embedding=pos_embedding, # (B, L, C) + hidden_states=all_hidden_states, # list of (B, L, C) + attentions=all_attentions, # list of (B, num_heads, L, head_dims) + ) + + def get_image_tokens(self, images, sample_patch_num, images_mask): + image_embedding = self.embed_images(images)[-1] + B, C, H, W = image_embedding.shape + num_patches = H * W + + padding_mask = images.new_zeros((B, num_patches)).bool() + position_col = torch.arange(W).unsqueeze(0) + position_row = torch.arange(H).unsqueeze(1) * self.image_bucket_size + position_idx = (position_col + position_row + 1).view(-1) + position_idx = position_idx.to(images.device).expand(B, num_patches) + + # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C) + image_embedding = image_embedding.flatten(2).transpose(1, 2) + if sample_patch_num is not None: + patch_orders = torch.stack([ + torch.randperm(num_patches)[:sample_patch_num] + for _ in range(B) + ]) + num_patches = sample_patch_num + image_embedding = image_embedding.gather( + dim=1, index=patch_orders.unsqueeze(2).expand(-1, -1, C)) + padding_mask = padding_mask.gather(1, patch_orders) + position_idx = position_idx.gather(1, patch_orders) + + pos_embedding = self.embed_image_positions(position_idx) + padding_mask[~images_mask] = True + return image_embedding, padding_mask, position_idx, pos_embedding + + def process_embedding(self, + embedding, + pos_embedding=None, + type_tokens=None, + embedding_ln=None): + if self.entangle_position_embedding and pos_embedding is not None: + embedding += pos_embedding + if self.embed_type is not None: + embedding += self.embed_type(type_tokens) + if embedding_ln is not None: + embedding = embedding_ln(embedding) + embedding = self.dropout(embedding) + + return embedding + + def get_rel_pos_bias(self, x, idx): + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) + values = values.permute([0, 3, 1, 2]) + return values.contiguous() + + def get_image_rel_pos_bias(self, image_position_ids, idx): + bsz, seq_len = image_position_ids.shape + rp_bucket_size = self.image_rp_bucket.size(1) + + rp_bucket = 
self.image_rp_bucket.unsqueeze(0).expand( + bsz, rp_bucket_size, rp_bucket_size).gather( + 1, image_position_ids[:, :, None].expand( + bsz, seq_len, rp_bucket_size)).gather( + 2, image_position_ids[:, None, :].expand( + bsz, seq_len, seq_len)) + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(0, 3, 1, 2) + return values + + +class OFADecoder(BaseModule): + """The decoder module of OFA. + + Args: + embed_tokens (nn.Embedding): The embedding module to embed the + input tokens. + num_layers (int): The number of decoder layers. Defaults to 6. + num_heads (int): The number of heads of attention. Defaults to 12. + dropout_rate (float): The prob of dropout for embedding and + transformer layers. Defaults to 0. + drop_path_rate (float): The prob of droppath for transformer layers. + Defaults to 0. + max_target_positions (int): The maximum length of the input tokens. + Defaults to 1024. + code_image_size (int): The resolution of the generated image in the + image infilling task. Defaults to 128. + token_bucket_size (int): The token bucket size, it's used as the + maximum relative position index in relative position embedding + of input tokens. Defaults to 256. + image_bucket_size (int): The image bucket size, it's used to generate + the image relative position embedding table. It should be larger + than the shape of image feature map. Defaults to 42. + attn_scale_factor (float): The scale factor to calculate qk scale in + attentions. Defaults to 2. + scale_embedding (bool): Whether to scale the embeddings by the square + root of the dimension. Defaults to False. + add_embedding_ln (bool): Whether to add an extra layer norm for token + embeddings. Defaults to True. + add_code_embedding_ln (bool): Whether to add an extra layer norm for + code embeddings. Defaults to True. + pre_norm (bool): Whether to do layer norm before attention and ffn + blocks in transformer layers. Defaults to True. + entangle_position_embedding (bool): Whether to add the position + embedding on the embeddings directly. Defaults to False. + share_input_output_embed (bool): Share the weights of the input token + embedding module and the output projection module. + Defaults to True. + init_cfg (dict, optional): The initialization config. Defaults to None. 
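+
+    Examples:
+        A minimal construction sketch (assuming ``OFADecoder`` is imported
+        from this module; the sizes are illustrative and far smaller than
+        the released OFA configurations):
+
+        >>> from torch import nn
+        >>> embed_tokens = nn.Embedding(100, 768, padding_idx=1)
+        >>> decoder = OFADecoder(embed_tokens, num_layers=2)
+
+        In normal use the decoder is driven by :class:`OFAEncoderDecoder`,
+        which supplies the encoder hidden states, padding masks and position
+        embeddings at every decoding step.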
+ """ + + def __init__( + self, + embed_tokens, + num_layers=6, + num_heads=12, + dropout_rate=0., + drop_layer_rate=0., + drop_path_rate=0., + max_target_positions=1024, + code_image_size=128, + token_bucket_size=256, + image_bucket_size=42, + attn_scale_factor=2., + scale_embedding=False, + add_embedding_ln=True, + add_code_embedding_ln=True, + pre_norm=True, + entangle_position_embedding=False, + share_input_output_embed=True, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self._future_mask = torch.empty(0) + + self.num_layers = num_layers + embedding_dim = embed_tokens.embedding_dim + self.embedding_dim = embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = max_target_positions + self.num_heads = num_heads + + # Build embedding process components + self.embed_tokens = embed_tokens + self.embedding_scale = math.sqrt( + embedding_dim) if scale_embedding else 1.0 + + if add_embedding_ln: + self.embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.embedding_ln = None + + if add_code_embedding_ln: + self.code_embedding_ln = nn.LayerNorm(embedding_dim) + else: + self.code_embedding_ln = None + + # Build position embedding + self.embed_positions = nn.Embedding(self.max_target_positions + 2, + embedding_dim) + self.pos_ln = nn.LayerNorm(embedding_dim) + self.embed_image_positions = nn.Embedding(image_bucket_size**2 + 1, + embedding_dim) + self.image_pos_ln = nn.LayerNorm(embedding_dim) + + self.pos_scaling = float(embedding_dim / num_heads * + attn_scale_factor)**-0.5 + self.self_pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.self_pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + self.cross_pos_q_linear = nn.Linear(embedding_dim, embedding_dim) + self.cross_pos_k_linear = nn.Linear(embedding_dim, embedding_dim) + + self.entangle_position_embedding = entangle_position_embedding + + self.dropout = nn.Dropout( + dropout_rate) if dropout_rate > 0. 
else nn.Identity() + if drop_layer_rate > 0.: + raise NotImplementedError + + # Register token relative position embedding table + self.token_bucket_size = token_bucket_size + token_num_rel_dis = 2 * token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(token_bucket_size) + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.token_rel_pos_table_list = nn.ModuleList() + + # Register image relative position embedding table + self.image_bucket_size = image_bucket_size + image_num_rel_dis = (2 * image_bucket_size - + 1) * (2 * image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position(image_bucket_size, + image_num_rel_dis) + self.register_buffer('image_rp_bucket', image_rp_bucket) + self.image_rel_pos_table_list = nn.ModuleList() + + self.window_size = code_image_size // 8 + + position_col = torch.arange(self.window_size).unsqueeze(0) + position_row = torch.arange( + self.window_size).unsqueeze(1) * self.image_bucket_size + image_position_idx = (position_col + position_row + 1) + image_position_idx = torch.cat( + [torch.tensor([0]), image_position_idx.view(-1)]) + image_position_idx = torch.cat( + [image_position_idx, + torch.tensor([1024] * 768)]) + self.register_buffer('image_position_idx', image_position_idx) + + # Build decoder layers + self.layers = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + for index in range(self.num_layers): + layer = OFADecoderLayer( + embedding_dim=embedding_dim, + num_heads=num_heads, + dropout_rate=dropout_rate, + drop_path_rate=dpr[index], + scale_factor=attn_scale_factor, + pre_norm=pre_norm, + ) + self.layers.append(layer) + token_pos_table = nn.Embedding(token_num_rel_dis, self.num_heads) + image_pos_table = nn.Embedding(image_num_rel_dis, self.num_heads) + nn.init.constant_(token_pos_table.weight, 0.) + nn.init.constant_(image_pos_table.weight, 0.) 
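+            # Each decoder layer owns its own pair of relative position bias
+            # tables (one for text tokens, one for image patches); they are
+            # zero-initialised and stored in the same order as the layers.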
+ self.token_rel_pos_table_list.append(token_pos_table) + self.image_rel_pos_table_list.append(image_pos_table) + + if pre_norm: + self.final_ln = nn.LayerNorm(embedding_dim) + else: + self.final_ln = None + + # Build output projection + if share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + vocab_size = self.embed_tokens.num_embeddings + self.output_projection = nn.Linear( + embedding_dim, vocab_size, bias=False) + nn.init.normal_( + self.output_projection.weight, + mean=0, + std=embedding_dim**-0.5, + ) + + main_input_name = 'input_ids' + + def forward( + self, + input_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + code_masks: Optional[torch.Tensor] = None, + encoder_pos_embedding: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + + if past_key_values is not None and len(past_key_values) > 0: + B, _, L_past, _ = past_key_values[0][0].shape + L = L_past + 1 + else: + B, L = input_ids.shape + L_past = 0 + + # Embed the token position + target_pos_idx = torch.arange( + L, device=input_ids.device).expand([B, L]).contiguous() + pos_embedding = self.embed_positions(target_pos_idx) + + # Embed the code positions + if code_masks is not None and torch.any(code_masks): + image_position_idx = self.image_position_idx[:input_ids.size(1)] + image_position_idx = image_position_idx.unsqueeze(0).expand(B, L) + pos_embedding[code_masks] = self.embed_image_positions( + image_position_idx)[code_masks] + + # Self-attention position bias (B, num_heads, L_t, L_t) + self_abs_pos_bias = self.get_pos_info(self.pos_ln(pos_embedding)) + if code_masks is not None and torch.any(code_masks): + self_image_abs_pos_bias = self.get_pos_info( + self.image_pos_ln(pos_embedding)) + self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] + + # Cross-attention position bias (B, num_heads, L_t, L_s) + cross_abs_pos_bias = self.get_pos_info( + self.pos_ln(pos_embedding), encoder_pos_embedding) + if code_masks is not None and torch.any(code_masks): + cross_image_abs_pos_bias = self.get_pos_info( + self.image_pos_ln(pos_embedding), encoder_pos_embedding) + cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[ + code_masks] + + all_prev_output_tokens = input_ids.clone() + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + cross_abs_pos_bias = cross_abs_pos_bias[:, :, -1:, :] + pos_embedding = pos_embedding[:, -1:, :] + + # Embed the input tokens + x = self.embed_tokens(input_ids) * self.embedding_scale + + if self.entangle_position_embedding: + x += pos_embedding + + if self.embedding_ln is not None: + if (code_masks is None or not code_masks.any() + or self.code_embedding_ln is None): + x = self.embedding_ln(x) + elif code_masks is not None and code_masks.all(): + x = self.code_embedding_ln(x) + else: + x[~code_masks] = self.embedding_ln(x[~code_masks]) + x[code_masks] = self.code_embedding_ln(x[code_masks]) + + x = self.dropout(x) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_ids.shape, x.dtype, L_past) + attention_mask = attention_mask.to(x.device) + + # decoder layers + all_hidden_states = [] if output_hidden_states 
else None + all_self_attns = [] if output_attentions else None + all_cross_attentions = [] if ( + output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = [] if use_cache else None + + for idx, layer in enumerate(self.layers): + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states.append(x) + + if past_key_values is not None and len(past_key_values) > 0: + past_key_value = past_key_values[idx] + else: + past_key_value = None + + self_attn_bias = self_abs_pos_bias.clone() + if code_masks is None or not code_masks.any(): + self_attn_bias += self.get_rel_pos_bias( + all_prev_output_tokens, idx) + elif code_masks is not None and code_masks.all(): + self_attn_bias += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx) + else: + self_attn_bias[~code_masks] += self.get_rel_pos_bias( + all_prev_output_tokens, idx) + self_attn_bias[code_masks] += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx) + + if past_key_value is not None: + self_attn_bias = self_attn_bias[:, :, -1:, :] + + out = layer( + x, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + self_attn_bias=self_attn_bias, + cross_attn_bias=cross_abs_pos_bias, + ) + x = out.pop(0) + + if output_attentions: + all_self_attns.append(out.pop(0)) + if encoder_hidden_states is not None: + all_cross_attentions.append(out.pop(0)) + + if use_cache: + next_decoder_cache.append(out.pop(0)) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (x, ) + + if self.final_ln is not None: + x = self.final_ln(x) + + x = self.output_projection(x) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=x, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + def _prepare_decoder_attention_mask( + self, + attention_mask, + input_shape, + dtype, + past_key_values_length, + ): + r""" + Create causal mask for unidirectional decoding. 
+ [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + """ + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + dtype, + past_key_values_length=past_key_values_length).to( + attention_mask.device) + + if attention_mask is not None: + # (B, L_s) -> (B, 1, L_t, L_s) + expanded_attention_mask = _expand_mask( + attention_mask, dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attention_mask if combined_attention_mask is None else + expanded_attention_mask + combined_attention_mask) + + return combined_attention_mask + + def get_pos_info(self, pos_embedding, src_pos_embedding=None): + B, tgt_len = pos_embedding.shape[:2] + if src_pos_embedding is not None: + src_len = src_pos_embedding.size(1) + pos_q = self.cross_pos_q_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + pos_q = pos_q * self.pos_scaling + pos_k = self.cross_pos_k_linear(src_pos_embedding).view( + B, src_len, self.num_heads, -1).transpose(1, 2) + else: + pos_q = self.self_pos_q_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + pos_q = pos_q * self.pos_scaling + pos_k = self.self_pos_k_linear(pos_embedding).view( + B, tgt_len, self.num_heads, -1).transpose(1, 2) + + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + return abs_pos_bias + + def get_rel_pos_bias(self, x, idx): + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) + values = values.permute([0, 3, 1, 2]) + return values.contiguous() + + def get_image_rel_pos_bias(self, image_position_ids, idx): + bsz, seq_len = image_position_ids.shape + rp_bucket_size = self.image_rp_bucket.size(1) + + rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( + bsz, rp_bucket_size, rp_bucket_size).gather( + 1, image_position_ids[:, :, None].expand( + bsz, seq_len, rp_bucket_size)).gather( + 2, image_position_ids[:, None, :].expand( + bsz, seq_len, seq_len)) + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(0, 3, 1, 2) + return values + + +class OFAEncoderDecoder(BaseModule, GenerationMixin): + """The OFA main architecture with an encoder and a decoder. + + Args: + encoder_cfg (dict): The config of the encoder, accept the keyword + arguments of :class:`OFAEncoder`. + decoder_cfg (dict): The config of the decoder, accept the keyword + arguments of :class:`OFADecoder`. + padding_idx (int): The index of the padding token. + vocab_size (int): The size of the vocabulary. + embedding_dim (int): The embedding dimensions of both the encoder + and the decoder. + generation_cfg (dict): The extra generation config, accept the keyword + arguments of :class:`~transformers.GenerationConfig`. + Defaults to an empty dict. + init_cfg (dict, optional): The initialization config. Defaults to None. 
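+
+    Examples:
+        A rough sketch of how the pieces fit together (the sub-config values
+        below are illustrative placeholders, not a released OFA setting;
+        real configs also match the vocabulary and tokenizer of the
+        corresponding checkpoint):
+
+        >>> model = OFAEncoderDecoder(
+        ...     encoder_cfg=dict(
+        ...         embed_images=dict(type='OFAResNet', depth=50),
+        ...         num_layers=2),
+        ...     decoder_cfg=dict(num_layers=2),
+        ...     padding_idx=1,
+        ...     vocab_size=30000,
+        ...     embedding_dim=768)
+
+        Text generation is then performed through the ``generate`` interface
+        inherited from ``transformers.GenerationMixin``.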
+ """ + base_model_prefix = '' + + def __init__( + self, + encoder_cfg, + decoder_cfg, + padding_idx, + vocab_size, + embedding_dim, + generation_cfg=dict(), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + + self.padding_idx = padding_idx + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + embed_tokens = nn.Embedding(vocab_size, embedding_dim, padding_idx) + + self.encoder = OFAEncoder(embed_tokens, **encoder_cfg) + self.decoder = OFADecoder(embed_tokens, **decoder_cfg) + + self.config = PretrainedConfig( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + bos_token_id=0, + decoder_start_token_id=0, + pad_token_id=1, + eos_token_id=2, + forced_eos_token_id=2, + use_cache=False, + is_encoder_decoder=True, + ) + self.config.update(generation_cfg) + + self.generation_config = GenerationConfig.from_model_config( + self.config) + + @property + def device(self): + return next(self.parameters()).device + + def can_generate(self): + return True + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + def get_normalized_probs(self, net_output, log_probs: bool, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + + def get_normalized_probs_scriptable( + self, + net_output, + log_probs: bool, + sample=None, + ): + """Scriptable helper function for get_normalized_probs in. + + ~BaseFairseqModel. + """ + if hasattr(self, 'decoder'): + return self.decoder.get_normalized_probs(net_output, log_probs, + sample) + elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) + logits = net_output.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + main_input_name = 'input_ids' + + def forward(self, + input_ids=None, + images=None, + images_mask=None, + sample_patch_num=None, + decoder_input_ids=None, + code_masks=None, + attention_mask=None, + encoder_outputs=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + constrain_fn=None, + return_dict=False): + """Forword the module. + + Args: + input_ids (torch.Tensor): The indices of the input tokens in the + vocabulary, and padding will be ignored by default. The indices + can be obtained using :class:`OFATokenizer`. + The shape is (B, L). + images (torch.Tensor): The input images. The shape is (B, 3, H, W). + images_mask (torch.Tensor): The mask of all available images. The + shape is (B, ). + sample_patch_num (int): The number of patches to sample for the + images. Defaults to None, which means to use all patches. + decoder_input_ids (torch.Tensor): The indices of the input tokens + for the decoder. + code_masks (torch.Tensor): The mask of all samples for image + generation. The shape is (B, ). + attention_mask (torch.Tensor): The attention mask for decoding. + The shape is (B, L). + encoder_outputs (OFAEncoderOutput): The encoder outputs with hidden + states, positional embeddings, and padding masks. + past_key_values (Tuple[Tuple[torch.Tensor]]): If use cache, the + parameter is a tuple of length ``num_layers``. 
Every item is + also a tuple with four tensors, two for the key and value of + self-attention, two for the key and value of cross-attention. + use_cache (bool): Whether to use cache for faster inference. + Defaults to False. + output_attentions (bool): Whether to output attention weights. + Defaults to False. + output_hidden_states (bool): Whether to output hidden states. + Defaults to False. + constrain_fn (Callable, optional): The function to constrain the + output logits. Defaults to None. + return_dict (bool): Not used, it's only for compat with the + interface of the ``generate`` of ``transformers``. + + Returns: + Seq2SeqLMOutput: + + - logits (``torch.Tensor``): The last decoder hidden states. + The shape is (B, L, C). + - past_key_values (``Tuple[Tuple[torch.Tensor]]``): The past keys + and values for faster inference. + - decoder_hidden_states (``Tuple[torch.Tensor]``): the decoder + hidden states of all layers. + - decoder_attentions (``Tuple[torch.Tensor]``): The self-attention + weights of all layers in the decoder. + - cross_attentions (``Tuple[torch.Tensor]``): The cross-attention + weights of all layers in the decoder. + - encoder_last_hidden_state (``torch.Tensor``): The last encoder + hidden states. + - encoder_hidden_states (``Tuple[torch.Tensor]``): The encoder + hidden states of all layers, including the embeddings. + - encoder_attentions (``Tuple[torch.Tensor]``): The self-attention + weights of all layers in the encoder. + """ + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + images=images, + images_mask=images_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + sample_patch_num=sample_patch_num, + ) + + if decoder_input_ids.eq(self.padding_idx).any(): + attention_mask = decoder_input_ids.eq(self.padding_idx) + + encoder_hidden_states = encoder_outputs.last_hidden_state + encoder_attention_mask = _expand_mask(encoder_outputs.padding_mask, + encoder_hidden_states.dtype, + decoder_input_ids.shape[-1]) + src_pos_embed = encoder_outputs.position_embedding + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_masks, + encoder_pos_embedding=src_pos_embed, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + # The constrain operation for fine-tuned model in OFA is applied + # before log_softmax, therefore we cannot use + # `prefix_allowed_tokens_fn` to implement it. 
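+        # The expected contract (a sketch, not a fixed API) is
+        # ``constrain_fn(decoder_input_ids, logits) -> logits``, e.g. filling
+        # the entries of disallowed vocabulary items with ``-inf`` before the
+        # caller applies log_softmax.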
+ if constrain_fn is not None: + logits = constrain_fn(decoder_input_ids, + decoder_outputs.last_hidden_state) + else: + logits = decoder_outputs.last_hidden_state + + return Seq2SeqLMOutput( + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + decoder_input_ids=None, + past=None, + attention_mask=None, + code_masks=None, + use_cache=False, + encoder_outputs=None, + constrain_fn=None, + **kwargs): + # if attention_mask is None: + attention_mask = decoder_input_ids.new_zeros(decoder_input_ids.shape) + + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + 'input_ids': None, + 'images': None, + 'images_mask': None, + 'sample_patch_num': None, + 'attention_mask': attention_mask, + 'encoder_outputs': encoder_outputs, + 'past_key_values': past, + 'decoder_input_ids': decoder_input_ids, + 'code_masks': code_masks, + 'use_cache': use_cache, + 'constrain_fn': constrain_fn, + } + + def _prepare_encoder_decoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str] = None): + # 1. get encoder + encoder = self.get_encoder() + + # 2. prepare encoder args and encoder kwargs from model kwargs + irrelevant_prefix = [ + 'decoder_', 'cross_attn', 'use_cache', 'attention_mask', + 'constrain_fn' + ] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + + if encoder_kwargs.get('images_mask') is None: + encoder_kwargs['images_mask'] = torch.tensor([True] * + inputs_tensor.size(0)) + + # 3. 
make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name or self.main_input_name
+        encoder_kwargs[model_input_name] = inputs_tensor
+        model_kwargs['encoder_outputs']: ModelOutput = encoder(
+            **encoder_kwargs)
+        model_kwargs['attention_mask'] = None
+
+        return model_kwargs
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(
+                past_state.index_select(0, beam_idx)
+                for past_state in layer_past), )
+        return reordered_past
+
+    @staticmethod
+    def _expand_inputs_for_generation(
+        input_ids: torch.LongTensor,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[ModelOutput] = None,
+        **model_kwargs,
+    ):
+        expanded_return_idx = (
+            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(
+                1, expand_size).view(-1).to(input_ids.device))
+        input_ids = input_ids.index_select(0, expanded_return_idx)
+
+        if attention_mask is not None:
+            model_kwargs['attention_mask'] = attention_mask.index_select(
+                0, expanded_return_idx)
+
+        if is_encoder_decoder:
+            if encoder_outputs is None:
+                raise ValueError('If `is_encoder_decoder` is True, make '
+                                 'sure that `encoder_outputs` is defined.')
+            encoder_outputs['last_hidden_state'] = encoder_outputs.\
+                last_hidden_state.index_select(0, expanded_return_idx)
+            encoder_outputs['position_embedding'] = encoder_outputs.\
+                position_embedding.index_select(0, expanded_return_idx)
+            encoder_outputs['padding_mask'] = encoder_outputs.\
+                padding_mask.index_select(0, expanded_return_idx)
+            model_kwargs['encoder_outputs'] = encoder_outputs
+        return input_ids, model_kwargs
diff --git a/mmpretrain/models/multimodal/otter/__init__.py b/mmpretrain/models/multimodal/otter/__init__.py
new file mode 100644
index 0000000..38a45a3
--- /dev/null
+++ b/mmpretrain/models/multimodal/otter/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .otter import Otter
+
+__all__ = ['Otter']
diff --git a/mmpretrain/models/multimodal/otter/otter.py b/mmpretrain/models/multimodal/otter/otter.py
new file mode 100644
index 0000000..7d30b50
--- /dev/null
+++ b/mmpretrain/models/multimodal/otter/otter.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import torch
+
+from mmpretrain.registry import MODELS, TOKENIZER
+from mmpretrain.structures import DataSample
+from ..flamingo.flamingo import ExtendModule, Flamingo, PerceiverResampler
+
+
+@MODELS.register_module()
+class Otter(Flamingo):
+    """The Otter model for multiple tasks.
+
+    Args:
+        vision_encoder (dict): The config of the vision encoder.
+        lang_encoder (dict): The config of the language encoder.
+        tokenizer (dict): The tokenizer to encode the text.
+        task (str): The task to perform prediction.
+        zeroshot_prompt (str): Prompt used for zero-shot inference.
+            Defaults to an empty string.
+        shot_prompt_tmpl (str): Prompt used for few-shot inference.
+            Defaults to ``<image>User:Please describe the image.
+            GPT:<answer>{caption}<|endofchunk|>``.
+        final_prompt_tmpl (str): Final part of prompt used for inference.
+            Defaults to '<image>User:Please describe the image. GPT:<answer>'.
+        generation_cfg (dict): The extra generation config, accept the keyword
+            arguments of :class:`~transformers.GenerationConfig`.
+            Defaults to an empty dict.
+        data_preprocessor (Optional[dict]): The config for preprocessing input
+            data. If None or no specified type, it will use
+            "MultiModalDataPreprocessor" as type.
+            See :class:`MultiModalDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (dict, optional): The initialization config. Defaults to None.
+    """
+
+    support_tasks = {'caption', 'vqa'}
+    _no_split_modules = [
+        'TransformerEncoderLayer', 'PerceiverAttention',
+        'GatedCrossAttentionBlock', 'FlamingoLayer'
+    ]
+
+    def __init__(
+            self,
+            vision_encoder: dict,
+            lang_encoder: dict,
+            tokenizer: dict,
+            task: str = 'caption',
+            zeroshot_prompt: str = '',
+            shot_prompt_tmpl: str = ('<image>User:Please describe the image. '
+                                     'GPT:<answer>{caption}<|endofchunk|>'),
+            final_prompt_tmpl: str = ('<image>User:Please describe the image. '
+                                      'GPT:<answer>'),
+            generation_cfg: dict = dict(),
+            data_preprocessor: Optional[dict] = None,
+            init_cfg: Optional[dict] = None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        if isinstance(data_preprocessor, dict):
+            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
+            data_preprocessor = MODELS.build(data_preprocessor)
+
+        super(Flamingo, self).__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        if task not in self.support_tasks:
+            raise ValueError(f'Unsupported task {task}, please select '
+                             f'the task from {self.support_tasks}.')
+        self.task = task
+
+        # init tokenizer
+        self.tokenizer = TOKENIZER.build(tokenizer)
+        # add Otter special tokens to the tokenizer
+        self.tokenizer.add_special_tokens({
+            'additional_special_tokens':
+            ['<|endofchunk|>', '<image>', '<answer>']
+        })
+        self.tokenizer.bos_token_id = 1
+        if self.tokenizer.pad_token is None:
+            # Issue: GPT models don't have a pad token, which we use to
+            # modify labels for the loss.
+            self.tokenizer.add_special_tokens({'pad_token': '<PAD>'})
+
+        # Template to format the prompt input
+        self.zeroshot_prompt = zeroshot_prompt
+        self.shot_prompt_tmpl = shot_prompt_tmpl
+        self.final_prompt_tmpl = final_prompt_tmpl
+
+        # init vision encoder related modules
+        vision_encoder_weight = vision_encoder.pop('pretrained', None)
+        self.vision_encoder = MODELS.build(vision_encoder)
+        if vision_encoder_weight is not None:
+            from mmengine.runner.checkpoint import load_checkpoint
+            load_checkpoint(
+                self.vision_encoder,
+                vision_encoder_weight,
+                map_location='cpu',
+                revise_keys=[(r'^backbone\.', '')],
+            )
+            self.vision_encoder.is_init = True
+
+        self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
+
+        # init language encoder related modules
+        self.lang_encoder = ExtendModule(**lang_encoder)
+        self.lang_encoder.resize_token_embeddings(len(self.tokenizer))
+        self.lang_encoder.media_token_id = self.tokenizer.encode('<image>')[-1]
+
+        # other necessary parameters
+        self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1]
+        self.generation_cfg = generation_cfg
+
+        if hasattr(self, 'register_load_state_dict_post_hook'):
+            self.register_load_state_dict_post_hook(self._load_adapter_hook)
+
+    def post_process(
+            self, outputs: torch.Tensor,
+            data_samples: Optional[List[DataSample]]) -> List[DataSample]:
+        """Perform post processing on the outputs of different tasks.
+
+        Args:
+            outputs (torch.Tensor): The generated outputs.
+            data_samples (List[DataSample], optional): The annotation
+                data of every sample.
+
+        Returns:
+            List[DataSample]: Return list of data samples.
+ """ + outputs = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True) + + if data_samples is None: + data_samples = [DataSample() for _ in range(len(outputs))] + + for output, data_sample in zip(outputs, data_samples): + # remove text pattern + if self.task == 'caption': + data_sample.pred_caption = output + elif self.task == 'vqa': + data_sample.pred_answer = output + + return data_samples diff --git a/mmpretrain/models/multimodal/ram/__init__.py b/mmpretrain/models/multimodal/ram/__init__.py new file mode 100644 index 0000000..35619d8 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ram import RAM, RAMNormal, RAMOpenset + +__all__ = ['RAM', 'RAMNormal', 'RAMOpenset'] diff --git a/mmpretrain/models/multimodal/ram/bert.py b/mmpretrain/models/multimodal/ram/bert.py new file mode 100644 index 0000000..f54b2ce --- /dev/null +++ b/mmpretrain/models/multimodal/ram/bert.py @@ -0,0 +1,1197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modify from: +# https://github.com/xinyu1205/recognize-anything/blob/main/ram/models/bert.py + +import math +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class BertEmbeddings_nopos(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + # self.position_embeddings = nn.Embedding( + # config.max_position_embeddings, config.hidden_size) + '''self.LayerNorm is not snake-cased to stick with + TensorFlow model variable name and be able to load''' + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous + # in memory and exported when serialized + # self.register_buffer("position_ids", + # torch.arange(config.max_position_embeddings).expand((1, -1))) + # self.position_embedding_type = \ + # getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] # noqa: F841 + + # if position_ids is None: + # position_ids = self.position_ids[:, \ + # past_key_values_length : seq_length + \ + # past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + # if self.position_embedding_type == "absolute": + # position_embeddings = self.position_embeddings(position_ids) + # # print('add position_embeddings!!!!') + # embeddings += position_embeddings + embeddings = 
self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with + # TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous + # in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + # print('add position_embeddings!!!!') + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and \ + not hasattr(config, 'embedding_size'): + raise ValueError('''The hidden size (%d) is not a multiple of + the number of attention heads (%d)''' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * \ + self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + 
def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + # print(self.key.weight.shape) + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # compatible with higher versions of transformers + if key_layer.shape[0] > query_layer.shape[0]: + key_layer = key_layer[:query_layer.shape[0], :, :, :] + attention_mask = attention_mask[:query_layer.shape[0], :, :] + value_layer = value_layer[:query_layer.shape[0], :, :, :] + + # Take the dot product between "query" and "key" + # to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + \ + relative_position_scores_query + \ + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for + # all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
+ attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * \ + self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class 
BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + + if mode == 'tagging': + + assert encoder_hidden_states is not None, \ + '''encoder_hidden_states must be given + for cross-attention layers''' + + cross_attention_outputs = self.crossattention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + present_key_value = cross_attention_outputs[-1] + + else: + # decoder uni-directional self-attention + # cached key/values tuple is at positions 1,2 + self_attn_past_key_value = \ + (past_key_value[:2] + if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode == 'multimodal': + assert encoder_hidden_states is not None, \ + '''encoder_hidden_states must be + given for cross-attention layers''' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1: + -1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, 
+ encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn('''`use_cache=True` is incompatible with + gradient checkpointing. Setting `use_cache=False`...''' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
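+        # Note (descriptive comment, not in the original patch): for
+        # BERT-style inputs this first token is the [CLS] token.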
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there
+        # is an output-only bias for each token.
+        self.decoder = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that
+        # the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+    """An abstract class to handle weights initialization and a simple
+    interface for downloading and loading pretrained models."""
+
+    config_class = BertConfig
+    base_model_prefix = 'bert'
+    _keys_to_ignore_on_load_missing = [r'position_ids']
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version
+            # which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(
+                mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+    """The model can behave as an encoder (with only self-attention) as well
+    as a decoder, in which case a layer of cross-attention is added between
+    the self-attention layers, following the architecture described in
+    `Attention is all you need <https://arxiv.org/abs/1706.03762>`__ by
+    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
+    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the
+    :obj:`is_decoder` argument of the configuration set to :obj:`True`.
+    To be used in a Seq2Seq model, the model needs to be initialized with
+    both the :obj:`is_decoder` argument and :obj:`add_cross_attention` set
+    to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+    input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """Prunes heads of the model.
+
+        heads_to_prune:
+            dict of {layer_num: list of heads to prune in this layer}
+            See base class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def get_extended_attention_mask(self, attention_mask: Tensor,
+                                    input_shape: Tuple[int], device: device,
+                                    is_decoder: bool) -> Tensor:
+        """Makes broadcastable attention and causal masks so that future and
+        masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to,
+                zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask,
+            with the same dtype as :obj:`attention_mask.dtype`.
+        """
+        # We can provide a self-attention mask of dimensions
+        # [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it
+        # broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask
+            # in addition to the padding mask
+            # - if the model is an encoder, make the mask
+            # broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if is_decoder:
+                batch_size, seq_length = input_shape
+
+                seq_ids = torch.arange(seq_length, device=device)
+                causal_mask = seq_ids[None, None, :].repeat(
+                    batch_size, seq_length, 1) <= seq_ids[None, :, None]
+                # in case past_key_values are used we need to
+                # add a prefix ones mask to the causal mask
+                # causal and attention masks must have same type
+                # with pytorch version < 1.3
+                causal_mask = causal_mask.to(attention_mask.dtype)
+
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[
+                        1] - causal_mask.shape[1]
+                    causal_mask = torch.cat(
+                        [
+                            torch.ones(
+                                (batch_size, seq_length, prefix_seq_len),
+                                device=device,
+                                dtype=causal_mask.dtype),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )
+
+                extended_attention_mask = (
+                    causal_mask[:, None, :, :] *
+                    attention_mask[:, None, None, :])
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                '''Wrong shape for input_ids (shape {}) or attention_mask
+                (shape {})'''.format(input_shape, attention_mask.shape))
+
+        # Since attention_mask is 1.0
+        # for positions we want to attend and 0.0
+        # for masked positions, this operation will
+        # create a tensor which is 0.0 for positions
+        # we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores
+        # before the softmax, this is effectively
+        # the same as removing these entirely.
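+        # Note (descriptive comment, not in the original patch): for example,
+        # a 2D padding mask of [1, 1, 0] is mapped to [0.0, 0.0, -10000.0] by
+        # the two statements below and is later added to the raw attention
+        # scores in BertSelfAttention.forward().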
+ extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token + indices of the encoder input. This mask is used in + the cross-attention if the model is configured as + a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length : + obj:`config.n_layers` with each tuple having 4 tensors of shape : + obj:`(batch_size, num_heads, sequence_length - 1, + embed_size_per_head)`): + Contains precomputed key and value hidden states of the + attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally + input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to + this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj: + `(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value + states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + if is_decoder: + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache) + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError('''You cannot specify both + input_ids and inputs_embeds at the same time''') + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError('''You have to specify either + input_ids or inputs_embeds or encoder_embeds''') + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to + # make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = \ + (self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder)) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = \ + (encoder_hidden_states[0].size()) + else: + encoder_batch_size, encoder_sequence_length, _ = \ + (encoder_hidden_states.size()) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape + # [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape + # [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + 
encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token + indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right + language modeling loss (next word prediction). + Indices should be in + ``[-100, 0, ..., config.vocab_size]`` + (see ``input_ids`` docstring) Tokens with indices set to + ``-100`` are ignored (masked), the loss is only computed + for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj: + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention + blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally + input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to + this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj: + `(batch_size, sequence_length)`. 
+ use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states + are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import (BertTokenizer, + BertLMHeadModel, BertConfig) + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained( + 'bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", + return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + # sequence_output.shape torch.Size([85, 30, 768]) + # prediction_scores.shape torch.Size([85, 30, 30524]) + # labels.shape torch.Size([85, 30]) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift + # prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, + # the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/ram/config/__init__.py b/mmpretrain/models/multimodal/ram/config/__init__.py new file 
mode 100644 index 0000000..ef101fe --- /dev/null +++ b/mmpretrain/models/multimodal/ram/config/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py b/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py new file mode 100644 index 0000000..e4b8865 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# data settings +test_transforms_cfg = [ + dict(type='Resize', scale=(384, 384), interpolation='bicubic'), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + + +def get_ram_cfg(mode='normal'): + assert mode in ['normal', 'openset'], 'mode must "normal" or "openset"' + model_type = 'RAMNormal' if mode == 'normal' else 'RAMOpenset' + model_cfg = dict( + type=model_type, + tokenizer=dict( + type='BertTokenizer', + name_or_path='/public/DATA/qbw/ckpt/bert-base-uncased', + use_fast=False), + vision_backbone=dict( + type='SwinTransformer', + arch='large', + img_size=384, + window_size=12, + ), + tag_encoder={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30524, + 'encoder_width': 512, + 'add_cross_attention': True + }, + text_decoder={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30524, + 'encoder_width': 768, + 'add_cross_attention': True + }, + tagging_head={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 4, + 'num_hidden_layers': 2, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30522, + 'encoder_width': 512, + 'add_cross_attention': True, + 'add_tag_cross_attention': False + }, + data_preprocessor=dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, + ), + ) + return model_cfg diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle new file mode 100644 index 0000000000000000000000000000000000000000..0519d1ee759eacdad99df2811ff59432369e1599 GIT binary patch literal 51099 zcmZX-%W~wolBXv-aleqv%u{E2#rDxJ>>HB_;e~@QXSw=YU$Hn0Q9PSPW;{WRZKmYgt``_ii|MSm({^RaH|9$bV z|F$@;_p9}Bvw!;Yk0~in+vVxBKK=Q}aCyIcxB!d*wuk-t&p#%+!+v=>ua6&p{xP^; zZ?}Y_<->Bd-T_UP4-f0p>EPEf@Mg8%pDX!rTx||}3N0TFN5Y%s!|`Lgd04zIH|FUr zANV`uZ?eq_SUhgFyVP~wylu`l%7?=dM3<}O>v>tZd&||zDp$|z)8=ixI4@tpb+}yp zyqu_Yuv}fX!DICXLYE^S$^5oHo;RoUE;x>s`}2CgU)tFEc(~s@(C^9V^|<_4JS;)D 
[... base85-encoded binary patch data for ram_tag_list.pickle omitted ...]

literal 0
HcmV?d00001

diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..4abe105e3b347ab63c1dd8ac25977853918635c5
GIT binary patch
literal 50796

[... base85-encoded binary patch data for ram_tag_list_chinese.pickle omitted ...]
zAo)jwu60w`H_aWwh`UXWpOH!B@sDj$QwsgpKn8eqDc|Of`B2N^sm$6X8a}oegmIR4 zs{^}{DOxIgJRt)vQ(<}qRaINsG2bPOs_KB$9Ws$wEhy;xB6V4ek0Y&;V!DE6hC^n`*DCOW>lu~v zr{JgIP6CPVCzH))#UFHM+~mz^-$@}rA9~sAYW5~oPB}KfFEzdv<2>{Xf->Bz*r(w) zR%r2lw*gDn@f?qf=$Ww6ruDU95#^Q~$w(>Fx3B|zQP@A^_Fv6`8L7~o*@h+y&PJ%> zY5oGy;jsqv4G<%q+;JmENZ4t=S=Uik&DlR@@!$SACJbbcak9;I3aaR(_-i>*;>7?8 z>-_bX;$heQ5so@VMNir}k?5B#ij_hJ|C!C1l*bFc$=&k7R8(4sp_f-}X`v-$uK7CSf?RDqr+7%gS>^^_cv^Y9T#1qAM;hTojuiW^ z$=Pd9<6k8$wbsmU7+Mq|P$O>pYOpv6Ajs_$J)EC4DFOxl!_Y~n0a{$qtijZ_fH8cQ zr#Pd>P!ge*4TPz648~1^%40740G%n;swvur$Ug2|81shz@I(A)BfDI&+BY>{xX&e_ zm*Vrb*GBd)n(dxD;1)J__+h=RHIl4m5k_S*i=Be0>MJ3NsiMM2+S$jXXhVeHS#{Kw zY$abjpEhNCG-O^pGyXY$N9R)tdMHf52UJ3+IbHn*7z?X*k9yxJtzrzLhmP>IB(N-Q zQdqRj>UIiD7UNS;jC9WOLY&jh!}Tg6Rxe@=|5QDUtvQ}q0h-t{^RA>16+_4USTBW9 zi*-Gi0Q)1R68RU8YRm4gXA{5VS(P~y7!ya0(0b%g9^0a=yHKPd^TYNX0SrYZ`S;HBY*($M*CHHrw zoa(-(WHw!#qie`*Ry9>27_g_4U#C-6P2azffyuid`VW7V6mMzC0_$8iF49siSY26y*GWNY}yrVC6^$@bs7hzGr=%`NW}2{mLT>$ zS&RhnK_1xGz-2vjIBx1RYxsmCcPB0IU=8eVcVnTOS@>tb8)0IsIPeN(EEp!$z~8WZ zWU_?G_W)zuz<%WUZwCdl*~@H!LC4cR=bu;LTy(7OL{{nC+lw=RWq9ycpr#l7w{|M`Y^$7D zKSO+-;LM&amEt`Q9R*I3W`Rgf6*5ot*yJPGD#86(;6!JQA7ecCS0M*kvSH6EXoAJc zb<|jsfO7LDj7b`?YTL?*!fG@wxN_8%H7vi1FPWd=6d#!cDcA4MFN|_=i=cYmBn!fN znBu1kUu-2v)Yl%^oHxz}#BNr9wxG%8 zycgz(%U-5wKsTAaWSI?W{QCrl?w6=_tXAXERvunaRk3JN7>Yw$;8rh~k|@U~{d4Gn z*B~un2|#BM6yQZsh0(F>;$X_jC3C49Fc>P7lpacW?CdP7M{q3{vM0HD5Gt+6K7kdZ zf}lJv4LGV~Z?ilGc)17qGHTTOGlK$6f9g;|eh0F%L3E3Z+f?g>(Sl$EezpaPlmTiT z3#Wr>`Ch|`ezr~j#ITrXSgz<_veTo6Vk9CA;*qp_AJcv66~sU%QDiHum4Ja#LpAWL z-oJ{pi3(ZCJO6GU1Z&dY{4~@@*`N%TfMwjlBM>Ddus%1&xC2@?(hDf)VkPK z-cpShl#$A9YGl~{@Wr686}0)x2HJ+n3p3_a`YG9Da~&3pcI!M$d)@DU=u^)ceJ+yT zVPcFW5_m%yT(qqOU0%XMlqckC1mj*%1)UcAL{yQK$YnhIA?kk?js^0jgkMonSUeaRC+ z2C2@AGemd|!(+9|vKqyQL>!HLL_aLL!2BzL1K;VT=X6>pokqRRbfIf{F0^%P^+V#b z*sVmaF~;sIYc75_aQNC@`-BB|PmPNEJ^RA072%_rovvAU5z9`w8TMfPb)85)!64U* zlujdJ^O=9v(X52DDS+Wj79ZpqcaQLUUDRN#<=1D)sUF?UPw`i4c*0cOHtm2+R%LE$ z5A@sE1}HDRU8>uSo1pH(mCL%YqyoHq5=-`o&y8`t=Upy^>M7xZVDhBv=jEyp=!KBeqM4YuGuHcEcW zG$Bm(Z zw)jLw*XSBWgRG9HhK?=Nf!$_hp$S1@QOk75f0Hk$irTtcD%yzZLV=;Wrk&$V!`Tls z;_(TGXYv+{Z%*#tW;Fy-;)EvV+HBT7B@yg#w*%m4(q&=R4C9A&VjpL+;x0hri51;2_ zOmqgpG9T4MKS|bfa#|+Db6Th)K%z4@dVfH!Zfg~qt_ z#lFPJI+A?O$lW&IP0q7`%v#)dks+m$eX2&?Q$EN!FfE(D?inoLaDv@NQT9_eVX>`K z;!ShL17PfNXHy=4*DYqLRTG?L(A`LMN5E0CFaSG!xjV8pSw~%aNRxf-c*D61V%(oQfCzC9;tGL!cy#a+suvR?`;Q*aG|_>k@T<0bkd9Y=l3yoXbNw62^kM{l{ETv z$h4mxL#py^%sR3>w7nu#Y==s(v46+}lpWG%w1&?ofj+~gd*IG!4PdwidwdyTz_LTc zuQARX2#gvfM0Re|NIUzB7NZy0R5(ci$~q~Z;P1i@RyA?uE7?cp$>X^MQpU1#TQc@& zIMS?!@{_7scuZzYCD}iehL#gi_Gl$RJ;0-+Ux>dbIu{xKE_g9`;`pLo zxe57x#e;V+`Qp>EC+p^qmw4XUg{d$?NLFoG{CMS@{vcTzLUkifz0>07@uQt$QLZWg zA6jox+&{bZ#!N9FUs+US>YEiM=AgffePeEN;d+q<{*~SK!W;vV)0gR^^a*9?>tv7H z=_-p_&@hoJV@V*45u?*hEK+?6H^DK+Djx~Y3{`CTy^a%+QV0=O=J04fBG^ARE$y};BK#2`wel0C&8*V zz7BroFCvT#5qOkg>;Vr73JOm}!Z0H*NiYz)P0bAg_ISvAX{s`^$=?(?zJv0LGf*p< z9c4ur`PTD4j5eqoR-?SDJ`4FMrJ$~s*@ktrl7_EU7*n5E7Gf%$agF1*uGqQLZ%QVG z%QV*g%b#sHCelBml7fi#DOWG3@Xd3oIIKzu4-ozrI&SOQIVMJI($H~W_!ox}KD-+LbECqKGP67dmsnZwETbbpWh~k9H2>{EEezc?EKGf6qFbYRz@{-m#jmr zh9UA;xn4@iys3G^U6SJb6a38eW_TM-*-M=n9DYbAdXnQeTlS>isQDg)SN`zqHt1G! 
zCE->EvLSMRWyCkK-PP=pT$+>wS8=>oNgZO8;oKZxb0?9ywridzPW6o`MVaQlCvKlP zJsg&CXkJn=)ZT7~TbC+R=$kf6&yF{;-x^4D_rDT0v?iHH(ECd<0A^Wf1T0TA+-u7h zI|N$wkk+t=toPg2V`)>9)RUY}oNG#Y4_yaMl@)tJa=X#ZZhYDg(l$RcA_ z_+i{&PnL3iU$V78LG6FB9K8)QhA2AtN1S+QXie3`kY|z27pMeIebUnJN@q(L^{jMErEDLu-6pjbwx;Pm4~A=&pmB8(e|0pylY-hA;oTaHtf zaH-U<0~OkDrLbsNQBVpTJMn??ZCmk5gNa5k%TJiTNYsB%ip*FdSj+_j{aN6~i2d*v zfws-K|IRl3p;$CrHqDhgsr#=~@$RY#RI1Gi%T-F@24*Z~<|FL{+)WSr{xbRl6>rkr z6lXE0Bc9K|m$~rUA{iQ$1a34nmA(JN5r{^hY`35dm{Jt~kU*f}5uh?B2QChY+o^oN z1o*d2ccIxsoKp|_dkB=%!kg(T=1bB(po-Qjx;!mtlg=r*W%xU*?idS?Hf-8|i)osB zL5}H3q#O?!r<*Rivv-!(VOrtMG55``dL2u(bEa61{L;&9@Xo1^7>2%9@Y|(DiL2zd zT8&;SfR&PE>Nu$Ja6^k!Xt5EHe~(@LWMLC3VTsxUpDAI=pgo%`YT-NSUl?h+1WRil zAzY*89il9GI+HKp&K&K%Z@S^AZ>h?)PjS{eG`07u({{yT5xf|uWPguhjq}$oJqO1_5#3tP@3kzpn zo{Yb7T@xLFL)-EX>ixDwPl#^*Mg&>dGTksm1lJ#vn&Zx@NU5dS%s;_Bzvg$H6|D!r z=@zk+{Ee)8>MsgNCQ~-fr zE+Z7*OCqGUok>@%2z5*AOQubMhqj=XY9zc$b%&eJw3L7~!u%<}DDattEh_@0@_=6R zb$-fe{>n%hr>=#3IO!R=30eSMd>}1w3nZJeCh0BX+VbmalvNOtiQ%5N3dY|tLCqt) z*);E?JV8(dL>0KhniGgZ)6?DkR>Vb!G(M7LS}ikHO_|`hs3%XGlE( z@!yKA5gs;u_sMjY?bDK_RGxd{-OLiUou0PC zQ59FSJ?iO-MYpRsy-tzhs+5w+s$eR1QD-p;j9Vya28O52-s^Ps3YPMZQG- zgZo2^Ze89qa3Ik-hIUjXc0K4$&ZCyOFIPdh_S%zGc#lx-ctmOlSRF>J2Q|Q?^*L{r z-7dlE^(;cbL69%KlC8&Q4RR3ZvKoMN96C34>OjBEDui$XI1RR}mxlG#XhrAs&v>u0 z47%KU&=)|pLa=G6?zpj%km;dConwhR4gz|mYP6km#(SMDCNjSj2e&E0#wg5}^lDaY zP$C>URWWsKc}}cT*4CL?d<~KBBpPrH7-D8XlFBmPval^>u*SofEJVk-IXNXKQ(60% z@H?X!TjAQRNcE?QZO40b@*_VDxX$3H!{p5NR)miTbE>hk3_;VjLs}w}2$*`9zh;wH zSksAXDe3`NO*B!P&r07PksH7{DTx=RK;oOL&Y8E}JHqY9pw!9K*;8*1$kn+5c^Sp3 zv&8^OzTL>51h9foCEW1=MhfYt#gG9ig66`EYoIYSsn!~DkN&nhC zQ`-=N)b@}?N!F{Wo&yDPEJ&0;cH$aJ3t(d}6N6Q;CmL>kpyI$332qeaW4?rM_WcIm zq+oK`1C;^W(8;uImEg4b1sCxOs)b2D$SC>A31Th#q4bgybI;zIGOlu+`DaVxrp++c zOs-}+>KUc!DEP4FlcS)#BE6>&^LXCCYc&f8bfKSG{D)-!`$U}7LP~T zgf|ycbApZlVx1$uU)qsLx z{*YaaSTQ5kCz#(+p3|#lweCK;;ij*mymP2}-8aEh$1AC1PuXy`q|Paclx2YtNVE!` z_gCx2Bk=_k(KpFQ~-H!CcAmk;lbGJs3elHD*3p^oghpa-C}z!a)W zDyRH8^u%>&;(X)0E@A%IsG?BL5gZBWFs<-j3##+Jy|_8@b_Vqisv1;}q|z=-$y1-E1eq9;&ta&UX^pUx95YqLVJAI2hc z%nzu8Maa?O4T`*22&dx2xMFS4!+D&M?++Hgc-Csa(vSEeUIBA`BcCClBBaowz<-fj zy7bg=294&SmgV@F?J_6l7-foi5Ft=#i(Gr#SynvA+rLuqkD^SLmFs&-WTZ-BlP!o& z6d^5ljaG#P!OsJfSxFVydAM=5o~Bp&JR@!|ht*^32{cjY1?6L8aHT_19zLk7Q^o0l zU(BZFJMO7;)ax8Gy4Sf4H_9ZF%C=VvURS&)H6W{zY92jV_O&{`$Rq9;lZ9213j4`fDg0ortfMK>S2h}n!v2n3M zor8-xf%44eF|1Xr$=ilck$idOo0?U}bMCvx(N6=L5?J7NQIBL67Im7E(Zy8ua8#L- zrM09nj!iZWc>BAdO!-aj={N7w!Lrj4%fLtKunzAfI^Ss8Uz??_^SV+pc=W~fj}c9| zQasv{XW>=t2iLj8naY$4XGJ7yS19_(GhiA%JmEyv>n=EGi9OwJ@_Fr)) zSSPy(rJ+J~_F$AdNHbr{EUMt8;o8}tO7(I%K^qfXKXsI(8S0X=IB7|z#Fwzj1Gao0 zC-f^d-#%jD=-0_O7VcEZw`8x2p=SlJNc!*)1fqd(tA1!sYbSH2xp!%by4PNo@MXa* zRro6@$93ONFU4)yrGB(rU$X8MXBf&xktk2Td5*ikIN7iDw0?CaRf4;B76Beehb)iS z3+MUKKKjchD{*TW9qY+eD&6F6F+O8yvtXG^v?1E*<5I96{=cAKmcS5nMyHv-SM!F| zQ?0NlhB{&X)O^tzcq$#9v=8fL$ZSqKUFon`hO2lr?(xTfUFM80N+nM6Xf(&l$w}kd z47Q)jS)h$^vnWu}O&Sp$%P_1Lck^$A2qs-ZopUy%2=FIbvKp4o-1eYUnxiu`QuS^2 zd<&Mc-v%U%5~#xVgxBwG5-T0jSOL)FjZjnhpZ~*u;D`R-p3ZHhVIYd44+TL8^iTAw z{E`~66j7tCctHe9z4gHhqA$g-87KJ**WPEwmvS;q(&=PoaxVLv&8PH7b6xgF9o|Cy zi-(TA27yC%x)5S%sD;k|1Ni|exTKCv7)2$9Gv+Xv)SL!&51Ix9yO1YbhZYn&8Z8V3 zw?*%DVf3*TU{(Jh3C~s*C}M$+2wE1PJ!d-ETysXCq%mco5%_l%0;jvhqr_SUFX$nK z+rZpYrc%@cy_u?S_EO@X2pgWg0y!Q5LdD>8`B}T9T3qu{Likt7<#EOapYcx0Kdr$> zczxHbFRnuHIwd7cE&O6FJ=ds(WT=AZEQfXXD2H^?@g01sauI2)NYN6*14ijfdcw8Q zkUu54QeufZR_S`9Jd={5TCgsK3ofa+lx{DE?~fr!z-TN6zlxa@PRI%R!DHd3q0trm zt?_Fur1?JIdZ%R=RcbdKYL0RM_~D@JWY9lodkg_yV-gFwPbE{ zZN7Soz)s%+ePQCkJ0tB;eGse_OK3*RE_!M(Kq+byz`+Ax&CERY35CUx7phLKBIyIy zQLIe8u}ocXG93fhu-3DP 
wTo3}U5qvXt0AG1B;nV<~PFv6Lk8$ z9IF$B!@W zzx&&3{g^{z4Oo-E`0{UmxclD!hr=fH6WU<+22OA$wNZ{w1E>XPRB)_vgx!Mb)#V|% zJTBM91}T7js?_C}`}pw(=bVEf&K|U+0`LYxKCHOr7kcCYgv|*l8@}`&I~<$4Ke@IK zzc@@S_a0t5iT-uGJJW|v-eiiatTL9WXp-*^YS=xyuP*-q#y&l4@nsWACnyk zwQkOtHkgQrYy{1fs42nfD|E>I0W{96dm%qPtZ2Ucdp9c!u z2>@y};Q8?${cd;BbYqlob@mUO0C4u+XjIg~C#XV-1zzAd6^@R{MBM6=)5A|VyMyP) zyQcf+4)@cy-|V|p8e`8B31sft;4}c|5x^N0Lli)24Y-7T;INc6hsfwY#iNB<29y9a z{XoW$2*4>v)=n!*eE&hXA6#obou|iVjX%D)GdhB6ZJ^rglwk(I`MU)kt9$ih=di>w znKqicKOY@OFdP9I10Z*VQ~+x#40#JoqyZEKjP?(zP;10#$e~t~q67dH5dbLzug4gW z?8C;#&Zq$9yczj0e^Vx$Va?MWH{!^f9i3ZAp#oUdq&9L~)R9oaUH;RZlJ5vSO6TawN@Yb zeU84YM0H94NNogu$Gg!Gd1%!sA^?00-~a;0sSvRt5diYFVPhW**8t-Xxl8yl7-WEb zNCn_pCWnW(LGF|z)EYhG^o)d0C;_Jdx8So@0+j*VwKfuv<4gp=Iw5eb$fqyS&<_C4 zUf`;qw@<>c#21pU8Iyw1G1b19ONNASV?LfyvFZgL43W+~``!Lkdm<>@!+A zx&)8USx^B?9!0`IX?(n;jS9|a?nN~M8#^4;f=p2WslC9>In!Y%5mLwj4cQz=FdTu# z#~dPn-HY0gS^#s1m_q~r84yx%0tj+aYg0FJsyjj|fJPN2y^x}4;560o0a)b-Us92B(@>+@#mNUjk zaWV##$wV4}bJ&?9AILOMWCUmo6%hcbHLy0B`?h_re%K1{I+nz)WpeY@EYu=pKpOya zh|sFvM#32Z3>(#mjA5gKV}K99H%F-807$KkOnd+s5&@86{Csx$t+qKt4DfMk6bU{d zH7a2;+fIoqps#!)hePD%KvN42XC)CDt0Ko`)Ab)VAqUuplmY!X6(Xp@63cQo#=$jt zEka|+Dau0(QUEwa0MFk!o^t2DjQ|HZz$qdIR;_WY3XyODYWlIWRpH?JFxiku@PPwx z8>B`6AHX;A!#8D~9eyEs|J`dp=ndYZ_M2~B_<9wpzuEq|f$U3JqiQ&zvAt4Wl8>DC z+KM$F#s+dCYy~L*DF7*eQ$#ibN&svaQYMcgI0Hs<$jAYt0HahOV;BH(0Axo<-2!N; z!vWBk=N3{Xhsy%zg;f|DnC!<+x9}Z$nx{iIMhajRsqYqSY#uqF*_;vi77Rl_O$rX+ z=6VjNRZ|qss8v)OK~t2c7bX)Kt(rFWye6mBAiue-aP{zn)nrF*VfqTUbpJX1=#Ifp zhk&r3NinpD1Q47<3PWEj8x`^ftQiKCm)~n&3~!o;6Z(PE0LD)+Bn7}ZJdZaS6)AvI zL;~RavQJQjo`Z^%4>u_T8uLs-$||0L;XujzOPIznQ5<~#_jKF9h zgYpTY1 zjpN@+(jfzT%^4Mj0H=uT4k!xPt3>ln&!$#r%&Cp{gEsU4XnKZYLs9{Rp|#P_a-P>Y zMTa2;2N0TvGXkUPjdNqTarBOE@v44KUlR)>+&=u(Hc|ku$C_OShy8;gDS$_UAFBpA zJ~szk{mmDt-|{zQMy*lJ03ZL>j#LBuXWtrC;Lxf$Ybkss(UEi2z7#1#f@d zN-Y4jnzw&eAzuTbXk!@q2`QYztQF+VN{ z)f2SZGjN7`ZuTH%6^yZ3EdpMAmk=ogJY>04y)f z35N6>I2bk@$PK`dHX2o!v616Ur~pB=aas82-r=8Qe{yaK+EfpXZIt6|pKwXzo)^+1h$B@f0kr8mKj*JWg+0Rk}`Rl)a{K3T? zuW3=)&=JUhuz_nSoEyNv20S@TTK68Fo3(5hrVEqzig>pW0Z_{kQX^oW=KbKt5|(qA zc$4MEhycRMiV{o(L;&~%@Fw|NzA3|Ypfhka-tm$Os^;Vi8*S(2fV;6Ka;Dbl&#gPP z7I7QF2eLs>-9n0fsMSQ8;g4Yh!svBBjUk-9Z(XfF{ z2{`uQTy{}Zfba$t&MknZRoZBt;b9m6h6ds86Ee<308Pr5b<2Z~$wUko5>7eN6h&pS zBai`b4B!O7D~P}esf}}8aBfm;av;O;*-Z+$Bcw*4nF^@o(d)Ak@~Dk9GPMl64=E8w z0zQxdqyVH0IQ2|+#PSdc@{n3H3@TE-zn~44#E?h8gAI8|!2z(;2hbs$09cPTDu!0m zVPhB?+nAy{yJcjEyisv7q5>f2$DiSg8<78f_jrGM?lG#JQLUvKIh+7Ob*oG!0@&;E RFoV`@G?8z<{PN*XzXGC--R1xQ literal 0 HcmV?d00001 diff --git a/mmpretrain/models/multimodal/ram/gradio_demo.py b/mmpretrain/models/multimodal/ram/gradio_demo.py new file mode 100644 index 0000000..206e6b4 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/gradio_demo.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
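+# Usage sketch (illustrative; the checkpoint paths are placeholders for
+# locally downloaded RAM and CLIP weights):
+#
+#     python -m mmpretrain.models.multimodal.ram.gradio_demo \
+#         /path/to/ram_swin_large_14m.pth /path/to/clip-vit-base-p16.pth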
+import argparse + +import gradio as gr +import torch + +from mmpretrain.registry import MODELS, TRANSFORMS +from .config.ram_swin_large_14m import get_ram_cfg, test_transforms_cfg +from .run.inference import inference + +parser = argparse.ArgumentParser( + description='RAM(Recognize Anything Model) demo') +parser.add_argument( + 'ram_ckpt', type=str, help='pretrained file for ram (absolute path)') +parser.add_argument( + 'clip_ckpt', + type=str, + help='clip vit-base-p16 pretrained file (absolute path)') +args = parser.parse_args() + +if torch.cuda.is_available(): + devices = [ + torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count()) + ] +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + devices = [torch.device('mps')] +else: + devices = [torch.device('cpu')] + + +def get_free_device(): + if hasattr(torch.cuda, 'mem_get_info'): + free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices] + select = max(zip(free, range(len(free))))[1] + else: + import random + select = random.randint(0, len(devices) - 1) + return devices[select] + + +device = get_free_device() + + +def ram_inference(image, tag_list, mode, threshold): + test_transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg) + model = MODELS.build(get_ram_cfg(mode=mode)) + model.load_state_dict(torch.load(args.ram_ckpt)) + model.device = device + + if mode == 'openset': + categories = tag_list + if categories != '': + categories = categories.strip().split() + else: + categories = None + model.set_openset( + categories=categories, + clip_ckpt=args.clip_ckpt, + threshold=threshold) + + sample = dict(img=image) + result = inference(sample, model, test_transforms, mode=mode) + tag, tag_chinese, logits = \ + result.get('tag_output')[0][0], result.get('tag_output')[1][0],\ + result.get('logits_output')[0] + + def wrap(tags, logits): + if tags is None: + return 'Openset mode has no tag_en' + tag_lst = tags.split('|') + rt_lst = [] + for i, tag in enumerate(tag_lst): + tag = tag.strip() + rt_lst.append(tag + f': {logits[i]:.2f}') + return ' | '.join(rt_lst) + + return [wrap(tag, logits), wrap(tag_chinese, logits)] + + +def build_gradio(): + inputs = [ + gr.components.Image(label='image'), + gr.components.Textbox( + lines=2, + label='tag_list', + placeholder= + 'please input the categories split by keyboard "blank": ', + value=''), + gr.components.Radio(['normal', 'openset'], + label='mode', + value='normal'), + gr.components.Slider( + minimum=0, maximum=1, value=0.68, step=0.01, label='threshold') + ] + return gr.Interface( + fn=ram_inference, + inputs=inputs, + outputs=[ + gr.components.Textbox(), + gr.components.Textbox(info="it's translated from the english tags") + ]) + + +def main(): + build_gradio().launch() + + +if __name__ == '__main__': + main() diff --git a/mmpretrain/models/multimodal/ram/openset_utils.py b/mmpretrain/models/multimodal/ram/openset_utils.py new file mode 100644 index 0000000..5fa0f52 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/openset_utils.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
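+# The helpers in this module expand raw category names into CLIP text prompts.
+# A minimal sketch of the intended behaviour (illustrative only):
+#
+#     >>> article('orange'), article('dog')
+#     ('an', 'a')
+#     >>> processed_name('hot_dog/sausage', rm_dot=True)
+#     'hot dog or sausage'
+#     >>> multiple_templates[0].format('cat', article=article('cat'))
+#     'There is a cat in the scene.'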
+import torch + +from mmpretrain.registry import MODELS + + +def article(name): + return 'an' if name[0] in 'aeiou' else 'a' + + +def processed_name(name, rm_dot=False): + # _ for lvis + # / for obj365 + res = name.replace('_', ' ').replace('/', ' or ').lower() + if rm_dot: + res = res.rstrip('.') + return res + + +single_template = ['a photo of a {}.'] + +multiple_templates = [ + 'There is {article} {} in the scene.', + 'There is the {} in the scene.', + 'a photo of {article} {} in the scene.', + 'a photo of the {} in the scene.', + 'a photo of one {} in the scene.', + 'itap of {article} {}.', + 'itap of my {}.', # itap: I took a picture of + 'itap of the {}.', + 'a photo of {article} {}.', + 'a photo of my {}.', + 'a photo of the {}.', + 'a photo of one {}.', + 'a photo of many {}.', + 'a good photo of {article} {}.', + 'a good photo of the {}.', + 'a bad photo of {article} {}.', + 'a bad photo of the {}.', + 'a photo of a nice {}.', + 'a photo of the nice {}.', + 'a photo of a cool {}.', + 'a photo of the cool {}.', + 'a photo of a weird {}.', + 'a photo of the weird {}.', + 'a photo of a small {}.', + 'a photo of the small {}.', + 'a photo of a large {}.', + 'a photo of the large {}.', + 'a photo of a clean {}.', + 'a photo of the clean {}.', + 'a photo of a dirty {}.', + 'a photo of the dirty {}.', + 'a bright photo of {article} {}.', + 'a bright photo of the {}.', + 'a dark photo of {article} {}.', + 'a dark photo of the {}.', + 'a photo of a hard to see {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of {article} {}.', + 'a low resolution photo of the {}.', + 'a cropped photo of {article} {}.', + 'a cropped photo of the {}.', + 'a close-up photo of {article} {}.', + 'a close-up photo of the {}.', + 'a jpeg corrupted photo of {article} {}.', + 'a jpeg corrupted photo of the {}.', + 'a blurry photo of {article} {}.', + 'a blurry photo of the {}.', + 'a pixelated photo of {article} {}.', + 'a pixelated photo of the {}.', + 'a black and white photo of the {}.', + 'a black and white photo of {article} {}.', + 'a plastic {}.', + 'the plastic {}.', + 'a toy {}.', + 'the toy {}.', + 'a plushie {}.', + 'the plushie {}.', + 'a cartoon {}.', + 'the cartoon {}.', + 'an embroidered {}.', + 'the embroidered {}.', + 'a painting of the {}.', + 'a painting of a {}.', +] + +openimages_rare_unseen = [ + 'Aerial photography', 'Aircraft engine', 'Ale', 'Aloe', 'Amphibian', + 'Angling', 'Anole', 'Antique car', 'Arcade game', 'Arthropod', + 'Assault rifle', 'Athletic shoe', 'Auto racing', 'Backlighting', + 'Bagpipes', 'Ball game', 'Barbecue chicken', 'Barechested', 'Barquentine', + 'Beef tenderloin', 'Billiard room', 'Billiards', 'Bird of prey', + 'Black swan', 'Black-and-white', 'Blond', 'Boating', 'Bonbon', + 'Bottled water', 'Bouldering', 'Bovine', 'Bratwurst', 'Breadboard', + 'Briefs', 'Brisket', 'Brochette', 'Calabaza', 'Camera operator', 'Canola', + 'Childbirth', 'Chordophone', 'Church bell', 'Classical sculpture', + 'Close-up', 'Cobblestone', 'Coca-cola', 'Combat sport', 'Comics', + 'Compact car', 'Computer speaker', 'Cookies and crackers', + 'Coral reef fish', 'Corn on the cob', 'Cosmetics', 'Crocodilia', + 'Digital camera', 'Dishware', 'Divemaster', 'Dobermann', 'Dog walking', + 'Domestic rabbit', 'Domestic short-haired cat', 'Double-decker bus', + 'Drums', 'Electric guitar', 'Electric piano', 'Electronic instrument', + 'Equestrianism', 'Equitation', 'Erinaceidae', 'Extreme sport', 'Falafel', + 'Figure skating', 'Filling station', 'Fire apparatus', 'Firearm', + 'Flatbread', 
'Floristry', 'Forklift truck', 'Freight transport', + 'Fried food', 'Fried noodles', 'Frigate', 'Frozen yogurt', 'Frying', + 'Full moon', 'Galleon', 'Glacial landform', 'Gliding', 'Go-kart', 'Goats', + 'Grappling', 'Great white shark', 'Gumbo', 'Gun turret', 'Hair coloring', + 'Halter', 'Headphones', 'Heavy cruiser', 'Herding', 'High-speed rail', + 'Holding hands', 'Horse and buggy', 'Horse racing', 'Hound', + 'Hunting knife', 'Hurdling', 'Inflatable', 'Jackfruit', 'Jeans', 'Jiaozi', + 'Junk food', 'Khinkali', 'Kitesurfing', 'Lawn game', 'Leaf vegetable', + 'Lechon', 'Lifebuoy', 'Locust', 'Lumpia', 'Luxury vehicle', 'Machine tool', + 'Medical imaging', 'Melee weapon', 'Microcontroller', 'Middle ages', + 'Military person', 'Military vehicle', 'Milky way', 'Miniature Poodle', + 'Modern dance', 'Molluscs', 'Monoplane', 'Motorcycling', 'Musical theatre', + 'Narcissus', 'Nest box', 'Newsagent\'s shop', 'Nile crocodile', + 'Nordic skiing', 'Nuclear power plant', 'Orator', 'Outdoor shoe', + 'Parachuting', 'Pasta salad', 'Peafowl', 'Pelmeni', 'Perching bird', + 'Performance car', 'Personal water craft', 'Pit bull', 'Plant stem', + 'Pork chop', 'Portrait photography', 'Primate', 'Procyonidae', + 'Prosciutto', 'Public speaking', 'Racewalking', 'Ramen', + 'Rear-view mirror', 'Residential area', 'Ribs', 'Rice ball', + 'Road cycling', 'Roller skating', 'Roman temple', 'Rowing', 'Rural area', + 'Sailboat racing', 'Scaled reptile', 'Scuba diving', 'Senior citizen', + 'Shallot', 'Shinto shrine', 'Shooting range', 'Siberian husky', 'Sledding', + 'Soba', 'Solar energy', 'Sport climbing', 'Sport utility vehicle', + 'Steamed rice', 'Stemware', 'Sumo', 'Surfing Equipment', 'Team sport', + 'Touring car', 'Toy block', 'Trampolining', 'Underwater diving', + 'Vegetarian food', 'Wallaby', 'Water polo', 'Watercolor paint', 'Whiskers', + 'Wind wave', 'Woodwind instrument', 'Yakitori', 'Zeppelin' +] + + +def get_clip_model(): + model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')), + pre_norm=True, + ), + projection=dict( + type='CLIPProjection', in_channels=768, out_channels=512), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + context_length=77, + data_preprocessor=dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, + ), + ) + return MODELS.build(model) + + +def build_openset_label_embedding(categories=None, clip_ckpt_path=''): + if categories is None: + print('Categories is None, so using rare_unseen categories') + categories = openimages_rare_unseen + model = get_clip_model() + model.load_state_dict(torch.load(clip_ckpt_path)) + templates = multiple_templates + + run_on_gpu = torch.cuda.is_available() + + with torch.no_grad(): + openset_label_embedding = [] + for category in categories: + texts = [ + template.format( + processed_name(category, rm_dot=True), + article=article(category)) for template in templates + ] + texts = [ + 'This is ' + text + if text.startswith('a') or text.startswith('the') else text + for text in texts + ] + texts = model.tokenize(texts) # tokenize + if run_on_gpu: + texts = 
texts.cuda() + model = model.cuda() + text_embeddings = model.extract_text_feat(texts) + text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) + text_embedding = text_embeddings.mean(dim=0) + text_embedding /= text_embedding.norm() + openset_label_embedding.append(text_embedding) + openset_label_embedding = torch.stack(openset_label_embedding, dim=1) + if run_on_gpu: + openset_label_embedding = openset_label_embedding.cuda() + + openset_label_embedding = openset_label_embedding.t() + return openset_label_embedding, categories diff --git a/mmpretrain/models/multimodal/ram/ram.py b/mmpretrain/models/multimodal/ram/ram.py new file mode 100644 index 0000000..c5d22f0 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/ram.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import pickle +from abc import abstractmethod +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .bert import BertConfig, BertLMHeadModel, BertModel +from .openset_utils import build_openset_label_embedding +from .utils import tie_encoder_decoder_weights + + +def get_path(path): + file_path = os.path.abspath(os.path.dirname(__file__)) + if not os.path.isabs(path): + return os.path.join(file_path, path) + + +class RAM(BaseModel): + """The implementation of `RAM `_.""" + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.device = device + # build the visual encoder + self.visual_encoder = MODELS.build(vision_backbone) + + # build the tokenizer + self.tokenizer = TOKENIZER.build(tokenizer) + self.tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + self.tokenizer.add_special_tokens( + {'additional_special_tokens': ['[ENC]']}) + self.tokenizer.enc_token_id = \ + self.tokenizer.additional_special_tokens_ids[0] + + # build the tag encoder + # encoder_config = BertConfig.from_json_file(med_config) + # encoder_config.encoder_width = 512 + encoder_config = BertConfig.from_dict(tag_encoder) + self.tag_encoder = BertModel( + config=encoder_config, add_pooling_layer=False) + + # build image-tag-text decoder + # decoder_config = BertConfig.from_json_file(med_config) + decoder_config = BertConfig.from_dict(text_decoder) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + self.delete_tag_index = delete_tag_index + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + # load tag list + self.tag_list = self.load_tag_list(get_path(tag_list)) + self.tag_list_chinese = self.load_tag_list(get_path(tag_list_chinese)) + + # create image-tag recognition decoder + self.threshold = threshold + self.num_class = len(self.tag_list) + # q2l_config = \ + # BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') + # 
q2l_config.encoder_width = 512 + q2l_config = BertConfig.from_dict(tagging_head) + self.tagging_head = BertModel( + config=q2l_config, add_pooling_layer=False) + self.tagging_head.resize_token_embeddings(len(self.tokenizer)) + self.label_embed = nn.Parameter( + torch.zeros(self.num_class, q2l_config.encoder_width)) + + if q2l_config.hidden_size != 512: + self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) + else: + self.wordvec_proj = nn.Identity() + + self.fc = nn.Linear(q2l_config.hidden_size, 1) + + self.del_selfattention() + + # share weights of the lowest 2-layer of + # "image-tag interaction encoder" with + # the "image-tag recogntion decoder" + tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', + ' ') + self.image_proj = nn.Linear(vision_width, 512) + # self.label_embed = nn.Parameter(torch.load( + # f'{CONFIG_PATH}/data/textual_label_embedding.pth', + # map_location='cpu').float()) + + # adjust thresholds for some tags + self.class_threshold = torch.ones(self.num_class) * self.threshold + ram_class_threshold_path = get_path( + './data/ram_tag_list_threshold.pickle') + with open(ram_class_threshold_path, 'rb') as f: + ram_class_threshold = pickle.load(f) + for key, value in enumerate(ram_class_threshold): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, 'rb') as f: + tag_list = pickle.load(f) + tag_list = np.array(tag_list) + return tag_list + + # delete self-attention layer of image-tag recognition decoder + # to reduce computation, follower Query2Label + def del_selfattention(self): + del self.tagging_head.embeddings + for layer in self.tagging_head.encoder.layer: + del layer.attention + + def get_label_embed(self): + return torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) + + def extract_visual_feature(self, images): + image_embeds = self.visual_encoder(images)[0] + image_embeds = image_embeds.flatten(2, 3) + attn_pool = nn.AdaptiveAvgPool1d(1) + cls_token = attn_pool(image_embeds).permute(0, 2, 1).contiguous() + image_embeds = image_embeds.permute(0, 2, 1).contiguous() + image_embeds = torch.cat([cls_token, image_embeds], dim=1) + image_embeds = self.image_proj(image_embeds) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(images.device) + return image_embeds, image_atts + + def image2tag(self, label_embed, image_embeds, image_atts): + # recognized image tags using image-tag recogntiion decoder + # image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', + ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + return logits + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + @abstractmethod + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + raise NotImplementedError + + +@MODELS.register_module() +class RAMNormal(RAM): + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, 
+ prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__( + tokenizer, + vision_backbone, + tag_encoder, + tagging_head, + text_decoder, + device, + vision_width, + prompt, + threshold, + delete_tag_index, + tag_list, + tag_list_chinese, + data_preprocessor, + init_cfg, + ) + + def tag_process(self, logits): + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(logits.device), + torch.tensor(1.0).to(logits.device), + torch.zeros(self.num_class).to(logits.device)) + + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + tag_output = [] + tag_output_chinese = [] + logits_output = [] + + bs = logits.shape[0] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + logits_output.append( + torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy()) + tag_output.append(' | '.join(token)) + token_chinese = self.tag_list_chinese[index].squeeze(axis=1) + tag_output_chinese.append(' | '.join(token_chinese)) + + return [(tag_output, tag_output_chinese), logits_output] + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + self.eval() + self.to(self.device) + images = images.to(self.device) + label_embed = self.get_label_embed() + image_embeds, image_atts = self.extract_visual_feature(images) + logits = self.image2tag(label_embed, image_embeds, image_atts) + tag_output, logits_output = self.tag_process(logits) + data_samples.set_field(logits_output, 'logits_output') + data_samples.set_field(tag_output, 'tag_output') + return data_samples + + +@MODELS.register_module() +class RAMOpenset(RAMNormal): + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__( + tokenizer, + vision_backbone, + tag_encoder, + tagging_head, + text_decoder, + device, + vision_width, + prompt, + threshold, + delete_tag_index, + tag_list, + tag_list_chinese, + data_preprocessor, + init_cfg, + ) + + def set_openset(self, + categories: List[str] = None, + clip_ckpt: str = '', + threshold: float = 0.68): + openset_label_embedding, openset_categories = \ + build_openset_label_embedding( + categories, clip_ckpt + ) + self.tag_list = np.array(openset_categories) + self.label_embed = nn.Parameter(openset_label_embedding.float()) + self.num_class = len(openset_categories) + + # the threshold for unseen categories is often lower + self.class_threshold = torch.ones(self.num_class) * threshold + + def tag_process(self, logits): + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(logits.device), + torch.tensor(1.0).to(logits.device), + torch.zeros(self.num_class).to(logits.device)) + + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + + bs = logits.shape[0] + tag_output = [] + logits_output = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + logits_output.append( + torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy()) + tag_output.append(' | 
'.join(token)) + + return [(tag_output, [None]), logits_output] diff --git a/mmpretrain/models/multimodal/ram/run/__init__.py b/mmpretrain/models/multimodal/ram/run/__init__.py new file mode 100644 index 0000000..ef101fe --- /dev/null +++ b/mmpretrain/models/multimodal/ram/run/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/models/multimodal/ram/run/inference.py b/mmpretrain/models/multimodal/ram/run/inference.py new file mode 100644 index 0000000..da5afcf --- /dev/null +++ b/mmpretrain/models/multimodal/ram/run/inference.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def inference_ram(sample, model): + + with torch.no_grad(): + result = model.test_step(sample) + + return result + + +def inference_ram_openset(sample, model): + with torch.no_grad(): + result = model.test_step(sample) + + return result + + +def inference(sample, model, transforms, mode='normal'): + sample = transforms(sample) + if sample['inputs'].ndim == 3: + sample['inputs'] = sample['inputs'].unsqueeze(dim=0) + assert mode in ['normal', 'openset' + ], 'mode of inference must be "normal" or "openset"' + if mode == 'normal': + return inference_ram(sample, model) + else: + return inference_ram_openset(sample, model) diff --git a/mmpretrain/models/multimodal/ram/utils.py b/mmpretrain/models/multimodal/ram/utils.py new file mode 100644 index 0000000..32cb115 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/utils.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from torch import nn + + +def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, + base_model_prefix: str, skip_key: str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + print(f'''{decoder.__class__} and {encoder.__class__} are not equal. 
+ In this case make sure that + all encoder weights are correctly initialized.''') + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f'{decoder_pointer} and {encoder_pointer}' + \ + 'have to be of type torch.nn.Module' + if hasattr(decoder_pointer, 'weight') and skip_key not in module_name: + assert hasattr(encoder_pointer, 'weight') + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, 'bias'): + assert hasattr(encoder_pointer, 'bias') + encoder_pointer.bias = decoder_pointer.bias + print(module_name + ' is tied') + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert (len(encoder_modules) > + 0), f'''Encoder module {encoder_pointer} + does not match decoder module {decoder_pointer}''' + + all_encoder_weights = set([ + module_name + '/' + sub_name + for sub_name in encoder_modules.keys() + ]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance( + decoder_modules[decoder_name], + type(encoder_modules[encoder_name])) and len( + encoder_modules) != len(decoder_modules): + # this can happen if the name corresponds to + # the position in a list module list of layers + # in this case the decoder has added a + # cross-attention that the encoder doesn't have + # thus skip this step and + # subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + '''Max depth of recursive function `tie_encoder_to_decoder` reached. + It seems that there is a circular dependency + between two or more `nn.Modules` of your model.''') + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + '/' + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + '/' + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, + uninitialized_encoder_weights, skip_key) diff --git a/mmpretrain/models/necks/__init__.py b/mmpretrain/models/necks/__init__.py new file mode 100644 index 0000000..2952a69 --- /dev/null +++ b/mmpretrain/models/necks/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
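+# Every neck below is registered in MODELS, so it can be built from a plain
+# config dict. A minimal sketch (assuming mmpretrain is importable):
+#
+#     >>> from mmpretrain.registry import MODELS
+#     >>> neck = MODELS.build(dict(type='GlobalAveragePooling', dim=2))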
+from .beitv2_neck import BEiTV2Neck +from .cae_neck import CAENeck +from .densecl_neck import DenseCLNeck +from .gap import GlobalAveragePooling +from .gem import GeneralizedMeanPooling +from .hr_fuse import HRFuseScales +from .itpn_neck import iTPNPretrainDecoder +from .linear_neck import LinearNeck +from .mae_neck import ClsBatchNormNeck, MAEPretrainDecoder +from .milan_neck import MILANPretrainDecoder +from .mixmim_neck import MixMIMPretrainDecoder +from .mocov2_neck import MoCoV2Neck +from .nonlinear_neck import NonLinearNeck +from .simmim_neck import SimMIMLinearDecoder +from .spark_neck import SparKLightDecoder +from .swav_neck import SwAVNeck + +__all__ = [ + 'GlobalAveragePooling', + 'GeneralizedMeanPooling', + 'HRFuseScales', + 'LinearNeck', + 'BEiTV2Neck', + 'CAENeck', + 'DenseCLNeck', + 'MAEPretrainDecoder', + 'ClsBatchNormNeck', + 'MILANPretrainDecoder', + 'MixMIMPretrainDecoder', + 'MoCoV2Neck', + 'NonLinearNeck', + 'SimMIMLinearDecoder', + 'SwAVNeck', + 'iTPNPretrainDecoder', + 'SparKLightDecoder', +] diff --git a/mmpretrain/models/necks/beitv2_neck.py b/mmpretrain/models/necks/beitv2_neck.py new file mode 100644 index 0000000..745e387 --- /dev/null +++ b/mmpretrain/models/necks/beitv2_neck.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class BEiTV2Neck(BaseModule): + """Neck for BEiTV2 Pre-training. + + This module construct the decoder for the final prediction. + + Args: + num_layers (int): Number of encoder layers of neck. Defaults to 2. + early_layers (int): The layer index of the early output from the + backbone. Defaults to 9. + backbone_arch (str): Vision Transformer architecture. Defaults to base. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initialization value for the + learnable scaling of attention and FFN. Defaults to 0.1. + use_rel_pos_bias (bool): Whether to use unique relative position bias, + if False, use shared relative position bias defined in backbone. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
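+
+    Example:
+        >>> # minimal construction sketch; the forward inputs are expected to
+        >>> # come from a BEiT-style backbone together with its rel_pos_bias
+        >>> neck = BEiTV2Neck(num_layers=2, early_layers=9, backbone_arch='base')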
+ """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'depth': 12, + 'num_heads': 12, + 'feedforward_channels': 3072, + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'depth': 24, + 'num_heads': 16, + 'feedforward_channels': 4096, + }), + } + + def __init__( + self, + num_layers: int = 2, + early_layers: int = 9, + backbone_arch: str = 'base', + drop_rate: float = 0., + drop_path_rate: float = 0., + layer_scale_init_value: float = 0.1, + use_rel_pos_bias: bool = False, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + + if isinstance(backbone_arch, str): + backbone_arch = backbone_arch.lower() + assert backbone_arch in set(self.arch_zoo), \ + (f'Arch {backbone_arch} is not in default archs ' + f'{set(self.arch_zoo)}') + self.arch_settings = self.arch_zoo[backbone_arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(backbone_arch, dict) and essential_keys <= set( + backbone_arch + ), f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = backbone_arch + + # stochastic depth decay rule + self.early_layers = early_layers + depth = self.arch_settings['depth'] + dpr = np.linspace(0, drop_path_rate, + max(depth, early_layers + num_layers)) + + self.patch_aggregation = nn.ModuleList() + for i in range(early_layers, early_layers + num_layers): + _layer_cfg = dict( + embed_dims=self.arch_settings['embed_dims'], + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + window_size=None, + use_rel_pos_bias=use_rel_pos_bias) + self.patch_aggregation.append( + BEiTTransformerEncoderLayer(**_layer_cfg)) + + self.rescale_patch_aggregation_init_weight() + + embed_dims = self.arch_settings['embed_dims'] + _, norm = build_norm_layer(norm_cfg, embed_dims) + self.add_module('norm', norm) + + def rescale_patch_aggregation_init_weight(self): + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.patch_aggregation): + rescale(layer.attn.proj.weight.data, + self.early_layers + layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, + self.early_layers + layer_id + 1) + + def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the latent prediction and final prediction. + + Args: + x (Tuple[torch.Tensor]): Features of tokens. + rel_pos_bias (torch.Tensor): Shared relative position bias table. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - ``x``: The final layer features from backbone, which are normed + in ``BEiTV2Neck``. + - ``x_cls_pt``: The early state features from backbone, which are + consist of final layer cls_token and early state patch_tokens + from backbone and sent to PatchAggregation layers in the neck. 
+ """ + + early_states, x = inputs[0], inputs[1] + x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1) + for layer in self.patch_aggregation: + x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias) + + # shared norm + x, x_cls_pt = self.norm(x), self.norm(x_cls_pt) + + # remove cls_token + x = x[:, 1:] + x_cls_pt = x_cls_pt[:, 1:] + return x, x_cls_pt diff --git a/mmpretrain/models/necks/cae_neck.py b/mmpretrain/models/necks/cae_neck.py new file mode 100644 index 0000000..81fc301 --- /dev/null +++ b/mmpretrain/models/necks/cae_neck.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer +from mmpretrain.registry import MODELS +from ..utils import CrossMultiheadAttention + + +class CAETransformerRegressorLayer(BaseModule): + """Transformer layer for the regressor of CAE. + + This module is different from conventional transformer encoder layer, for + its queries are the masked tokens, but its keys and values are the + concatenation of the masked and unmasked tokens. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): The number of heads in multi-head attention. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + drop_rate (float): The dropout rate. Defaults to 0.0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The init value of gamma. + Defaults to 0.0. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. 
+ """ + + def __init__( + self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + num_fcs: int = 2, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + layer_scale_init_value: float = 0.0, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN', eps=1e-6) + ) -> None: + super().__init__() + + # NOTE: cross attention + _, self.norm1_q_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm1_k_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm1_v_cross = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm2_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2) + self.cross_attn = CrossMultiheadAttention( + embed_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop_rate, + proj_drop=drop_rate) + + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=None, + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = DropPath(drop_prob=drop_path_rate) + + if layer_scale_init_value > 0: + self.gamma_1_cross = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + self.gamma_2_cross = nn.Parameter( + layer_scale_init_value * torch.ones((embed_dims)), + requires_grad=True) + else: + self.gamma_1_cross = nn.Parameter( + torch.ones((embed_dims)), requires_grad=False) + self.gamma_2_cross = nn.Parameter( + torch.ones((embed_dims)), requires_grad=False) + + def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, + pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = x_q + self.drop_path(self.gamma_1_cross * self.cross_attn( + self.norm1_q_cross(x_q + pos_q), + k=self.norm1_k_cross(x_kv + pos_k), + v=self.norm1_v_cross(x_kv))) + x = self.norm2_cross(x) + x = x + self.drop_path(self.gamma_2_cross * self.ffn(x)) + + return x + + +@MODELS.register_module() +class CAENeck(BaseModule): + """Neck for CAE Pre-training. + + This module construct the latent prediction regressor and the decoder + for the latent prediction and final prediction. + + Args: + num_classes (int): The number of classes for final prediction. Defaults + to 8192. + embed_dims (int): The embed dims of latent feature in regressor and + decoder. Defaults to 768. + regressor_depth (int): The number of regressor blocks. Defaults to 6. + decoder_depth (int): The number of decoder blocks. Defaults to 8. + num_heads (int): The number of head in multi-head attention. Defaults + to 12. + mlp_ratio (int): The expand ratio of latent features in MLP. defaults + to 4. + qkv_bias (bool): Whether or not to use qkv bias. Defaults to True. + qk_scale (float, optional): The scale applied to the results of qk. + Defaults to None. + drop_rate (float): The dropout rate. Defaults to 0. + attn_drop_rate (float): The dropout rate in attention block. Defaults + to 0. + norm_cfg (dict): The config of normalization layer. Defaults to + dict(type='LN', eps=1e-6). + layer_scale_init_value (float, optional): The init value of gamma. + Defaults to None. + mask_tokens_num (int): The number of mask tokens. Defaults to 75. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + num_classes: int = 8192, + embed_dims: int = 768, + regressor_depth: int = 6, + decoder_depth: int = 8, + num_heads: int = 12, + mlp_ratio: int = 4, + qkv_bias: bool = True, + qk_scale: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_cfg: dict = dict(type='LN', eps=1e-6), + layer_scale_init_value: float = None, + mask_tokens_num: int = 75, + init_cfg: dict = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.num_features = self.embed_dim = embed_dims + self.mask_token_num = mask_tokens_num + + # regressor + regressor_drop_path_rates = [ + x.item() + for x in torch.linspace(0, drop_path_rate, regressor_depth) + ] + self.regressors = nn.ModuleList([ + CAETransformerRegressorLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=regressor_drop_path_rates[i], + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value) + for i in range(regressor_depth) + ]) + + # decoder + decoder_drop_path_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, decoder_depth) + ] + self.decoders = nn.ModuleList([ + BEiTTransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + layer_scale_init_value=layer_scale_init_value, + window_size=None, + # setting `use_rel_pos_bias` to False ignores the `window_size` + use_rel_pos_bias=False, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=decoder_drop_path_rates[i], + norm_cfg=norm_cfg) for i in range(decoder_depth) + ]) + + _, self.norm_regressor = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + _, self.norm_decoder = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + + self.head = nn.Linear( + embed_dims, num_classes) if num_classes > 0 else nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + def init_weights(self) -> None: + """Initialization.""" + super().init_weights() + self.apply(self._init_weights) + trunc_normal_(self.mask_token, std=0.02) + trunc_normal_(self.head.weight, std=0.02) + + def _init_weights(self, m: nn.Module) -> None: + """Initialization.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, x_unmasked: torch.Tensor, pos_embed_masked: torch.Tensor, + pos_embed_unmasked: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the latent prediction and final prediction. + + Args: + x_unmasked (torch.Tensor): Features of unmasked tokens. + pos_embed_masked (torch.Tensor): Position embedding of masked + tokens. + pos_embed_unmasked (torch.Tensor): Position embedding of unmasked + tokens. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - ``logits``: Final prediction. + - ``latent_pred``: Latent prediction. 
+ """ + x_masked = self.mask_token.expand(x_unmasked.shape[0], + self.mask_token_num, -1) + # regressor + for regressor in self.regressors: + x_masked = regressor( + x_masked, torch.cat([x_unmasked, x_masked], dim=1), + pos_embed_masked, + torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1)) + x_masked = self.norm_regressor(x_masked) + latent_pred = x_masked + + # decoder + x_masked = x_masked + pos_embed_masked + for decoder in self.decoders: + x_masked = decoder(x_masked, rel_pos_bias=None) + x_masked = self.norm_decoder(x_masked) + + logits = self.head(x_masked) + + return logits, latent_pred diff --git a/mmpretrain/models/necks/densecl_neck.py b/mmpretrain/models/necks/densecl_neck.py new file mode 100644 index 0000000..bee9a23 --- /dev/null +++ b/mmpretrain/models/necks/densecl_neck.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DenseCLNeck(BaseModule): + """The non-linear neck of DenseCL. + + Single and dense neck in parallel: fc-relu-fc, conv-relu-conv. + Borrowed from the authors' `code `_. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + num_grid (int): The grid size of dense features. Defaults to None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + hid_channels: int, + out_channels: int, + num_grid: Optional[int] = None, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels)) + + self.with_pool = True if num_grid is not None else False + if self.with_pool: + self.pool = nn.AdaptiveAvgPool2d((num_grid, num_grid)) + self.mlp2 = nn.Sequential( + nn.Conv2d(in_channels, hid_channels, 1), nn.ReLU(inplace=True), + nn.Conv2d(hid_channels, out_channels, 1)) + self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function of neck. + + Args: + x (Tuple[torch.Tensor]): feature map of backbone. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - ``avgpooled_x``: Global feature vectors. + - ``x``: Dense feature vectors. + - ``avgpooled_x2``: Dense feature vectors for queue. + """ + assert len(x) == 1 + x = x[0] + + avgpooled_x = self.avgpool(x) + avgpooled_x = self.mlp(avgpooled_x.view(avgpooled_x.size(0), -1)) + + if self.with_pool: + x = self.pool(x) # sxs + x = self.mlp2(x) # sxs: bxdxsxs + avgpooled_x2 = self.avgpool2(x) # 1x1: bxdx1x1 + x = x.view(x.size(0), x.size(1), -1) # bxdxs^2 + avgpooled_x2 = avgpooled_x2.view(avgpooled_x2.size(0), -1) # bxd + return avgpooled_x, x, avgpooled_x2 diff --git a/mmpretrain/models/necks/gap.py b/mmpretrain/models/necks/gap.py new file mode 100644 index 0000000..0877743 --- /dev/null +++ b/mmpretrain/models/necks/gap.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class GlobalAveragePooling(nn.Module): + """Global Average Pooling neck. 
+ + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + + Args: + dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}. + Default: 2 + """ + + def __init__(self, dim=2): + super(GlobalAveragePooling, self).__init__() + assert dim in [1, 2, 3], 'GlobalAveragePooling dim only support ' \ + f'{1, 2, 3}, get {dim} instead.' + if dim == 1: + self.gap = nn.AdaptiveAvgPool1d(1) + elif dim == 2: + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + else: + self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) + + def init_weights(self): + pass + + def forward(self, inputs): + if isinstance(inputs, tuple): + outs = tuple([self.gap(x) for x in inputs]) + outs = tuple( + [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) + elif isinstance(inputs, torch.Tensor): + outs = self.gap(inputs) + outs = outs.view(inputs.size(0), -1) + else: + raise TypeError('neck inputs should be tuple or torch.tensor') + return outs diff --git a/mmpretrain/models/necks/gem.py b/mmpretrain/models/necks/gem.py new file mode 100644 index 0000000..f5648be --- /dev/null +++ b/mmpretrain/models/necks/gem.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from torch.nn.parameter import Parameter + +from mmpretrain.registry import MODELS + + +def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor: + if clamp: + x = x.clamp(min=eps) + return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p) + + +@MODELS.register_module() +class GeneralizedMeanPooling(nn.Module): + """Generalized Mean Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + + Args: + p (float): Parameter value. Defaults to 3. + eps (float): epsilon. Defaults to 1e-6. + clamp (bool): Use clamp before pooling. Defaults to True + p_trainable (bool): Toggle whether Parameter p is trainable or not. + Defaults to True. + """ + + def __init__(self, p=3., eps=1e-6, clamp=True, p_trainable=True): + assert p >= 1, "'p' must be a value greater than 1" + super(GeneralizedMeanPooling, self).__init__() + self.p = Parameter(torch.ones(1) * p, requires_grad=p_trainable) + self.eps = eps + self.clamp = clamp + self.p_trainable = p_trainable + + def forward(self, inputs): + if isinstance(inputs, tuple): + outs = tuple([ + gem(x, p=self.p, eps=self.eps, clamp=self.clamp) + for x in inputs + ]) + outs = tuple( + [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) + elif isinstance(inputs, torch.Tensor): + outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp) + outs = outs.view(inputs.size(0), -1) + else: + raise TypeError('neck inputs should be tuple or torch.tensor') + return outs diff --git a/mmpretrain/models/necks/hr_fuse.py b/mmpretrain/models/necks/hr_fuse.py new file mode 100644 index 0000000..4a97f86 --- /dev/null +++ b/mmpretrain/models/necks/hr_fuse.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
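# Illustrative sketch (not part of the patch, toy shapes assumed): the `gem`
# function in gem.py above reduces to plain average pooling at p=1, while a
# larger p puts more weight on the strongest activations; p defaults to 3 and
# stays trainable in `GeneralizedMeanPooling`.
import torch
import torch.nn.functional as F

x = torch.rand(2, 8, 7, 7) + 0.1               # fake (N, C, H, W) feature map
gem_p1 = F.avg_pool2d(x.clamp(min=1e-6), (7, 7))             # p = 1
assert torch.allclose(gem_p1, F.adaptive_avg_pool2d(x, 1))   # equals average pooling

p = 3.0
gem_p3 = F.avg_pool2d(x.clamp(min=1e-6).pow(p), (7, 7)).pow(1. / p)
print(gem_p3.view(gem_p3.size(0), -1).shape)   # torch.Size([2, 8]), as the neck returns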
+import torch.nn as nn +from mmcv.cnn.bricks import ConvModule +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..backbones.resnet import Bottleneck, ResLayer + + +@MODELS.register_module() +class HRFuseScales(BaseModule): + """Fuse feature map of multiple scales in HRNet. + + Args: + in_channels (list[int]): The input channels of all scales. + out_channels (int): The channels of fused feature map. + Defaults to 2048. + norm_cfg (dict): dictionary to construct norm layers. + Defaults to ``dict(type='BN', momentum=0.1)``. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to ``dict(type='Normal', layer='Linear', std=0.01))``. + """ + + def __init__(self, + in_channels, + out_channels=2048, + norm_cfg=dict(type='BN', momentum=0.1), + init_cfg=dict(type='Normal', layer='Linear', std=0.01)): + super(HRFuseScales, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.norm_cfg = norm_cfg + + block_type = Bottleneck + out_channels = [128, 256, 512, 1024] + + # Increase the channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + increase_layers = [] + for i in range(len(in_channels)): + increase_layers.append( + ResLayer( + block_type, + in_channels=in_channels[i], + out_channels=out_channels[i], + num_blocks=1, + stride=1, + )) + self.increase_layers = nn.ModuleList(increase_layers) + + # Downsample feature maps in each scale. + downsample_layers = [] + for i in range(len(in_channels) - 1): + downsample_layers.append( + ConvModule( + in_channels=out_channels[i], + out_channels=out_channels[i + 1], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + bias=False, + )) + self.downsample_layers = nn.ModuleList(downsample_layers) + + # The final conv block before final classifier linear layer. + self.final_layer = ConvModule( + in_channels=out_channels[3], + out_channels=self.out_channels, + kernel_size=1, + norm_cfg=self.norm_cfg, + bias=False, + ) + + def forward(self, x): + assert isinstance(x, tuple) and len(x) == len(self.in_channels) + + feat = self.increase_layers[0](x[0]) + for i in range(len(self.downsample_layers)): + feat = self.downsample_layers[i](feat) + \ + self.increase_layers[i + 1](x[i + 1]) + + return (self.final_layer(feat), ) diff --git a/mmpretrain/models/necks/itpn_neck.py b/mmpretrain/models/necks/itpn_neck.py new file mode 100644 index 0000000..1a3626a --- /dev/null +++ b/mmpretrain/models/necks/itpn_neck.py @@ -0,0 +1,388 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.models.backbones.hivit import BlockWithRPE +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import build_2d_sincos_position_embedding + + +class PatchSplit(nn.Module): + """The up-sample module used in neck (transformer pyramid network) + + Args: + dim (int): the input dimension (channel number). + fpn_dim (int): the fpn dimension (channel number). + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. 
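# Shape sketch (illustrative, not part of the patch): PatchSplit doubles the
# spatial resolution by projecting every token to 4 sub-tokens and unfolding
# them into a 2x2 grid, i.e. (B, N, H, W, C) -> (B, N, 2H, 2W, fpn_dim); the
# LayerNorm used in the real module is omitted here for brevity.
import torch
import torch.nn as nn

B, N, H, W, C, fpn_dim = 2, 1, 7, 7, 512, 256
x = torch.rand(B, N, H, W, C)
reduction = nn.Linear(C, fpn_dim * 4, bias=False)

y = reduction(x)                                 # (B, N, H, W, 4 * fpn_dim)
y = y.reshape(B, N, H, W, 2, 2, fpn_dim)
y = y.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, N, 2 * H, 2 * W, fpn_dim)
print(y.shape)                                   # torch.Size([2, 1, 14, 14, 256])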
+ """ + + def __init__(self, dim, fpn_dim, norm_cfg): + super().__init__() + _, self.norm = build_norm_layer(norm_cfg, dim) + self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False) + self.fpn_dim = fpn_dim + + def forward(self, x): + B, N, H, W, C = x.shape + x = self.norm(x) + x = self.reduction(x) + x = x.reshape(B, N, H, W, 2, 2, + self.fpn_dim).permute(0, 1, 2, 4, 3, 5, + 6).reshape(B, N, 2 * H, 2 * W, + self.fpn_dim) + return x + + +@MODELS.register_module() +class iTPNPretrainDecoder(BaseModule): + """The neck module of iTPN (transformer pyramid network). + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 512. + fpn_dim (int): The fpn dimension (channel number). + fpn_depth (int): The layer number of feature pyramid. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + reconstruction_type (str): The itpn supports 2 kinds of supervisions. + Defaults to 'pixel'. + num_outs (int): The output number of neck (transformer pyramid + network). Defaults to 3. + predict_feature_dim (int): The output dimension to supervision. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 512, + fpn_dim: int = 256, + fpn_depth: int = 2, + decoder_embed_dim: int = 512, + decoder_depth: int = 6, + decoder_num_heads: int = 16, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + reconstruction_type: str = 'pixel', + num_outs: int = 3, + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + predict_feature_dim: Optional[float] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.num_patches = num_patches + assert reconstruction_type in ['pixel', 'clip'], \ + 'iTPN method only support `pixel` and `clip`, ' \ + f'but got `{reconstruction_type}`.' 
+ self.reconstruction_type = reconstruction_type + self.num_outs = num_outs + + self.build_transformer_pyramid( + num_outs=num_outs, + embed_dim=embed_dim, + fpn_dim=fpn_dim, + fpn_depth=fpn_depth, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + rpe=False, + norm_cfg=norm_cfg, + ) + + # merge the output + self.decoder_embed = nn.ModuleList() + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim, bias=True), + )) + + if self.num_outs >= 2: + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True), + )) + if self.num_outs >= 3: + self.decoder_embed.append( + nn.Sequential( + nn.LayerNorm(fpn_dim), + nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True), + )) + + if reconstruction_type == 'pixel': + self.mask_token = nn.Parameter( + torch.zeros(1, 1, decoder_embed_dim)) + + # create new position embedding, different from that in encoder + # and is not learnable + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, decoder_embed_dim), + requires_grad=False) + + self.decoder_blocks = nn.ModuleList([ + TransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + self.decoder_norm_name, decoder_norm = build_norm_layer( + norm_cfg, decoder_embed_dim, postfix=1) + self.add_module(self.decoder_norm_name, decoder_norm) + + # Used to map features to pixels + if predict_feature_dim is None: + predict_feature_dim = patch_size**2 * in_chans + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + else: + _, norm = build_norm_layer(norm_cfg, embed_dim) + self.add_module('norm', norm) + + def build_transformer_pyramid(self, + num_outs=3, + embed_dim=512, + fpn_dim=256, + fpn_depth=2, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + rpe=False, + norm_cfg=None): + Hp = None + mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim} + if num_outs > 1: + if embed_dim != fpn_dim: + self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim) + else: + self.align_dim_16tofpn = None + self.fpn_modules = nn.ModuleList() + self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg)) + self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=False, + norm_cfg=norm_cfg, + )) + + self.align_dim_16to8 = nn.Linear( + mlvl_dims['8'], fpn_dim, bias=False) + self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg) + self.block_16to8 = nn.Sequential(*[ + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg, + ) for _ in range(fpn_depth) + ]) + + if num_outs > 2: + self.align_dim_8to4 = nn.Linear( + mlvl_dims['4'], fpn_dim, bias=False) + self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg) + self.block_8to4 = nn.Sequential(*[ + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg, + ) for _ in range(fpn_depth) + ]) + 
self.fpn_modules.append( + BlockWithRPE( + Hp, + fpn_dim, + 0, + mlp_ratio, + qkv_bias, + qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=0., + rpe=rpe, + norm_cfg=norm_cfg)) + + def init_weights(self) -> None: + """Initialize position embedding and mask token of MAE decoder.""" + super().init_weights() + + if self.reconstruction_type == 'pixel': + decoder_pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.decoder_pos_embed.shape[-1], + cls_token=False) + self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) + + torch.nn.init.normal_(self.mask_token, std=.02) + else: + self.rescale_init_weight() + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.fpn_modules): + if isinstance(layer, BlockWithRPE): + if layer.attn is not None: + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + @property + def decoder_norm(self): + """The normalization layer of decoder.""" + return getattr(self, self.decoder_norm_name) + + def forward(self, + x: torch.Tensor, + ids_restore: torch.Tensor = None) -> torch.Tensor: + """The forward function. + + The process computes the visible patches' features vectors and the mask + tokens to output feature vectors, which will be used for + reconstruction. + + Args: + x (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + ids_restore (torch.Tensor): ids to restore original image. + + Returns: + torch.Tensor: The reconstructed feature vectors, which is of + shape B x (num_patches) x C. + """ + + features = x[:2] + x = x[-1] + B, L, _ = x.shape + x = x[..., None, None, :] + Hp = Wp = math.sqrt(L) + + outs = [x] if self.align_dim_16tofpn is None else [ + self.align_dim_16tofpn(x) + ] + if self.num_outs >= 2: + x = self.block_16to8( + self.split_16to8(x) + self.align_dim_16to8(features[1])) + outs.append(x) + if self.num_outs >= 3: + x = self.block_8to4( + self.split_8to4(x) + self.align_dim_8to4(features[0])) + outs.append(x) + if self.num_outs > 3: + outs = [ + out.reshape(B, Hp, Wp, *out.shape[-3:]).permute( + 0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3], + Wp * out.shape[-2]).contiguous() + for out in outs + ] + if self.num_outs >= 4: + outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2)) + if self.num_outs >= 5: + outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2)) + + for i, out in enumerate(outs): + out = self.fpn_modules[i](out) + outs[i] = out + + if self.reconstruction_type == 'pixel': + feats = [] + for feat, layer in zip(outs, self.decoder_embed): + x = layer(feat).reshape(B, L, -1) + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x = torch.cat([x, mask_tokens], dim=1) + x = torch.gather( + x, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + feats.append(x) + x = feats.pop(0) + # add pos embed + x = x + self.decoder_pos_embed + + for i, feat in enumerate(feats): + x = x + feats[i] + # apply Transformer blocks + for i, blk in enumerate(self.decoder_blocks): + x = blk(x) + x = self.decoder_norm(x) + x = self.decoder_pred(x) + return x + else: + feats = [] + for feat, layer in zip(outs, self.decoder_embed): + x = layer(feat).reshape(B, L, -1) + feats.append(x) + x = feats.pop(0) + for i, feat in enumerate(feats): + x = x + feats[i] + + x = self.norm(x) 
+ + return x diff --git a/mmpretrain/models/necks/linear_neck.py b/mmpretrain/models/necks/linear_neck.py new file mode 100644 index 0000000..bcdbee2 --- /dev/null +++ b/mmpretrain/models/necks/linear_neck.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class LinearNeck(BaseModule): + """Linear neck with Dimension projection. + + Args: + in_channels (int): Number of channels in the input. + out_channels (int): Number of channels in the output. + gap_dim (int): Dimensions of each sample channel, can be one of + {0, 1, 2, 3}. Defaults to 0. + norm_cfg (dict, optional): dictionary to construct and + config norm layer. Defaults to dict(type='BN1d'). + act_cfg (dict, optional): dictionary to construct and + config activate layer. Defaults to None. + init_cfg (dict, optional): dictionary to initialize weights. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + gap_dim: int = 0, + norm_cfg: Optional[dict] = dict(type='BN1d'), + act_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + self.norm_cfg = copy.deepcopy(norm_cfg) + self.act_cfg = copy.deepcopy(act_cfg) + + assert gap_dim in [0, 1, 2, 3], 'GlobalAveragePooling dim only ' \ + f'support {0, 1, 2, 3}, get {gap_dim} instead.' + if gap_dim == 0: + self.gap = nn.Identity() + elif gap_dim == 1: + self.gap = nn.AdaptiveAvgPool1d(1) + elif gap_dim == 2: + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + elif gap_dim == 3: + self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) + + self.fc = nn.Linear(in_features=in_channels, out_features=out_channels) + + if norm_cfg: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + else: + self.norm = nn.Identity() + + if act_cfg: + self.act = build_activation_layer(act_cfg) + else: + self.act = nn.Identity() + + def forward(self, inputs: Union[Tuple, + torch.Tensor]) -> Tuple[torch.Tensor]: + """forward function. + + Args: + inputs (Union[Tuple, torch.Tensor]): The features extracted from + the backbone. Multiple stage inputs are acceptable but only + the last stage will be used. + + Returns: + Tuple[torch.Tensor]: A tuple of output features. + """ + assert isinstance(inputs, (tuple, torch.Tensor)), ( + 'The inputs of `LinearNeck` must be tuple or `torch.Tensor`, ' + f'but get {type(inputs)}.') + if isinstance(inputs, tuple): + inputs = inputs[-1] + + x = self.gap(inputs) + x = x.view(x.size(0), -1) + out = self.act(self.norm(self.fc(x))) + return (out, ) diff --git a/mmpretrain/models/necks/mae_neck.py b/mmpretrain/models/necks/mae_neck.py new file mode 100644 index 0000000..773692d --- /dev/null +++ b/mmpretrain/models/necks/mae_neck.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import build_2d_sincos_position_embedding + + +@MODELS.register_module() +class MAEPretrainDecoder(BaseModule): + """Decoder for MAE Pre-training. 
+ + Some of the code is borrowed from `https://github.com/facebookresearch/mae`. # noqa + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 1024. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + + Example: + >>> from mmpretrain.models import MAEPretrainDecoder + >>> import torch + >>> self = MAEPretrainDecoder() + >>> self.eval() + >>> inputs = torch.rand(1, 50, 1024) + >>> ids_restore = torch.arange(0, 196).unsqueeze(0) + >>> level_outputs = self.forward(inputs, ids_restore) + >>> print(tuple(level_outputs.shape)) + (1, 196, 768) + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + predict_feature_dim: Optional[float] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.num_patches = num_patches + + # used to convert the dim of features from encoder to the dim + # compatible with that of decoder + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + # create new position embedding, different from that in encoder + # and is not learnable + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, decoder_embed_dim), + requires_grad=False) + + self.decoder_blocks = nn.ModuleList([ + TransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + self.decoder_norm_name, decoder_norm = build_norm_layer( + norm_cfg, decoder_embed_dim, postfix=1) + self.add_module(self.decoder_norm_name, decoder_norm) + + # Used to map features to pixels + if predict_feature_dim is None: + predict_feature_dim = patch_size**2 * in_chans + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + + def init_weights(self) -> None: + """Initialize position embedding and mask token of MAE decoder.""" + super().init_weights() + + decoder_pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.decoder_pos_embed.shape[-1], + cls_token=True) + self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) + + torch.nn.init.normal_(self.mask_token, std=.02) + + @property + def decoder_norm(self): + """The normalization layer of decoder.""" + return getattr(self, self.decoder_norm_name) + + def forward(self, x: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: + """The forward function. + + The process computes the visible patches' features vectors and the mask + tokens to output feature vectors, which will be used for + reconstruction. + + Args: + x (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. 
+ ids_restore (torch.Tensor): ids to restore original image. + + Returns: + torch.Tensor: The reconstructed feature vectors, which is of + shape B x (num_patches) x C. + """ + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + x = torch.cat([x[:, :1, :], x_], dim=1) + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + # remove cls token + x = x[:, 1:, :] + + return x + + +@MODELS.register_module() +class ClsBatchNormNeck(BaseModule): + """Normalize cls token across batch before head. + + This module is proposed by MAE, when running linear probing. + + Args: + input_features (int): The dimension of features. + affine (bool): a boolean value that when set to ``True``, this module + has learnable affine parameters. Defaults to False. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-6. + init_cfg (Dict or List[Dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + input_features: int, + affine: bool = False, + eps: float = 1e-6, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.bn = nn.BatchNorm1d(input_features, affine=affine, eps=eps) + + def forward( + self, + inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]: + """The forward function.""" + # Only apply batch norm to cls_token + inputs = [self.bn(input_) for input_ in inputs] + return tuple(inputs) diff --git a/mmpretrain/models/necks/milan_neck.py b/mmpretrain/models/necks/milan_neck.py new file mode 100644 index 0000000..b48b767 --- /dev/null +++ b/mmpretrain/models/necks/milan_neck.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from torch import nn + +from mmpretrain.registry import MODELS +from ..backbones.vision_transformer import TransformerEncoderLayer +from ..utils import PromptMultiheadAttention +from .mae_neck import MAEPretrainDecoder + + +class PromptTransformerEncoderLayer(TransformerEncoderLayer): + """Prompt Transformer Encoder Layer for MILAN. + + This module is specific for the prompt encoder in MILAN. It will not update + the visible tokens from the encoder. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Defaults to 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Defaults to 0.0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Defaults to 2. + qkv_bias (bool): Enable bias for qkv if True. Defaults to True. + act_cfg (dict): The activation config for FFNs. + Defaults to ``dict(type='GELU')``. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Defaults to False. 
+ init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int, + feedforward_channels=int, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_fcs: int = 2, + qkv_bias: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + self.attn = PromptMultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias) + + def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: + """Forward function for `PromptMultiheadAttention`. + + Args: + x (torch.Tensor): Mask token features with shape N x L_m x C. + visible_tokens (torch.Tensor): The visible tokens features from + encoder with shape N x L_v x C. + ids_restore (torch.Tensor): The ids of all tokens in the original + image with shape N x L. + + Returns: + torch Tensor: Output features with shape N x L x C. + """ + x = x + self.attn(self.norm1(x), visible_tokens, ids_restore) + x = self.ffn(self.norm2(x), identity=x) + return x + + +@MODELS.register_module() +class MILANPretrainDecoder(MAEPretrainDecoder): + """Prompt decoder for MILAN. + + This decoder is used in MILAN pretraining, which will not update these + visible tokens from the encoder. + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 1024. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + predict_feature_dim (int): The dimension of the feature to be + predicted. Defaults to 512. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. 
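# Minimal sketch (toy sizes, class token ignored, not part of the patch) of
# the indexing pattern the MILAN decoder below relies on: `torch.gather` with
# an expanded index picks the visible tokens (ids_keep) and the masked tokens
# (ids_dump) out of the full, position-embedded sequence.
import torch

N, L, C = 2, 8, 4
tokens = torch.arange(N * L * C, dtype=torch.float32).reshape(N, L, C)
ids_keep = torch.tensor([[0, 2, 5], [1, 3, 4]])               # visible token ids
ids_dump = torch.tensor([[1, 3, 4, 6, 7], [0, 2, 5, 6, 7]])   # masked token ids

visible = torch.gather(
    tokens, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, C))
masked = torch.gather(
    tokens, dim=1, index=ids_dump.unsqueeze(-1).repeat(1, 1, C))
print(visible.shape, masked.shape)   # torch.Size([2, 3, 4]) torch.Size([2, 5, 4])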
+ """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + predict_feature_dim: int = 512, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + num_patches=num_patches, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + decoder_embed_dim=decoder_embed_dim, + decoder_depth=decoder_depth, + decoder_num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + # map the dim of features from decoder to the dim compatible with + # that of CLIP + self.decoder_pred = nn.Linear( + decoder_embed_dim, predict_feature_dim, bias=True) + + # use prompt transformer encoder layer, instead of the conventional + # transformer encoder layer + self.decoder_blocks = nn.ModuleList([ + PromptTransformerEncoderLayer( + decoder_embed_dim, + decoder_num_heads, + int(mlp_ratio * decoder_embed_dim), + qkv_bias=True, + norm_cfg=norm_cfg) for _ in range(decoder_depth) + ]) + + def forward(self, x: torch.Tensor, ids_restore: torch.Tensor, + ids_keep: torch.Tensor, + ids_dump: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): The input features, which is of shape (N, L, C). + ids_restore (torch.Tensor): The indices to restore these tokens + to the original image. + ids_keep (torch.Tensor): The indices of tokens to be kept. + ids_dump (torch.Tensor): The indices of tokens to be masked. + + Returns: + torch.Tensor: The reconstructed features, which is of shape + (N, L, C). + """ + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) + x = torch.cat([x[:, :1, :], x_], dim=1) + + # add pos embed + x = x + self.decoder_pos_embed + + # split mask tokens and visible tokens + visible_tokens = torch.cat([ + x[:, :1, :], + torch.gather( + x[:, 1:, :], + dim=1, + index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + ], + dim=1) + x = torch.gather( + x[:, 1:, :], + dim=1, + index=ids_dump.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + + for blk in self.decoder_blocks: + x = blk(x, visible_tokens, ids_restore) + + # full sequence recovery + x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1) + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, + x.shape[-1])) # unshuffle + x = torch.cat([visible_tokens[:, :1, :], x_], dim=1) + + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + return x diff --git a/mmpretrain/models/necks/mixmim_neck.py b/mmpretrain/models/necks/mixmim_neck.py new file mode 100644 index 0000000..8d67ee2 --- /dev/null +++ b/mmpretrain/models/necks/mixmim_neck.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from ..utils import build_2d_sincos_position_embedding +from .mae_neck import MAEPretrainDecoder + + +@MODELS.register_module() +class MixMIMPretrainDecoder(MAEPretrainDecoder): + """Decoder for MixMIM Pretraining. + + Some of the code is borrowed from `https://github.com/Sense-X/MixMIM`. 
# noqa + + Args: + num_patches (int): The number of total patches. Defaults to 196. + patch_size (int): Image patch size. Defaults to 16. + in_chans (int): The channel of input image. Defaults to 3. + embed_dim (int): Encoder's embedding dimension. Defaults to 1024. + encoder_stride (int): The output stride of MixMIM backbone. Defaults + to 32. + decoder_embed_dim (int): Decoder's embedding dimension. + Defaults to 512. + decoder_depth (int): The depth of decoder. Defaults to 8. + decoder_num_heads (int): Number of attention heads of decoder. + Defaults to 16. + mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. + Defaults to 4. + norm_cfg (dict): Normalization layer. Defaults to LayerNorm. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + encoder_stride: int = 32, + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + mlp_ratio: int = 4, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + + super().__init__( + num_patches=num_patches, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + decoder_embed_dim=decoder_embed_dim, + decoder_depth=decoder_depth, + decoder_num_heads=decoder_num_heads, + mlp_ratio=mlp_ratio, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, decoder_embed_dim), + requires_grad=False) + self.decoder_pred = nn.Linear(decoder_embed_dim, encoder_stride**2 * 3) + + def init_weights(self) -> None: + """Initialize position embedding and mask token of MixMIM decoder.""" + super(MAEPretrainDecoder, self).init_weights() + + decoder_pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.decoder_pos_embed.shape[-1], + cls_token=False) + self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) + + torch.nn.init.normal_(self.mask_token, std=.02) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): The input features, which is of shape (N, L, C). + mask (torch.Tensor): The tensor to indicate which tokens a + re masked. + + Returns: + torch.Tensor: The reconstructed features, which is of shape + (N, L, C). + """ + + x = self.decoder_embed(x) + B, L, C = x.shape + + mask_tokens = self.mask_token.expand(B, L, -1) + x1 = x * (1 - mask) + mask_tokens * mask + x2 = x * mask + mask_tokens * (1 - mask) + x = torch.cat([x1, x2], dim=0) + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for idx, blk in enumerate(self.decoder_blocks): + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + return x diff --git a/mmpretrain/models/necks/mocov2_neck.py b/mmpretrain/models/necks/mocov2_neck.py new file mode 100644 index 0000000..9ad9107 --- /dev/null +++ b/mmpretrain/models/necks/mocov2_neck.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MoCoV2Neck(BaseModule): + """The non-linear neck of MoCo v2: fc-relu-fc. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. 
+ out_channels (int): Number of output channels. + with_avg_pool (bool): Whether to apply the global + average pooling after backbone. Defaults to True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + hid_channels: int, + out_channels: int, + with_avg_pool: bool = True, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__(init_cfg) + self.with_avg_pool = with_avg_pool + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.mlp = nn.Sequential( + nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels)) + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function. + + Args: + x (Tuple[torch.Tensor]): The feature map of backbone. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + assert len(x) == 1 + x = x[0] + if self.with_avg_pool: + x = self.avgpool(x) + return (self.mlp(x.view(x.size(0), -1)), ) diff --git a/mmpretrain/models/necks/nonlinear_neck.py b/mmpretrain/models/necks/nonlinear_neck.py new file mode 100644 index 0000000..ef684d3 --- /dev/null +++ b/mmpretrain/models/necks/nonlinear_neck.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class NonLinearNeck(BaseModule): + """The non-linear neck. + + Structure: fc-bn-[relu-fc-bn] where the substructure in [] can be repeated. + For the default setting, the repeated time is 1. + The neck can be used in many algorithms, e.g., SimCLR, BYOL, SimSiam. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + num_layers (int): Number of fc layers. Defaults to 2. + with_bias (bool): Whether to use bias in fc layers (except for the + last). Defaults to False. + with_last_bn (bool): Whether to add the last BN layer. + Defaults to True. + with_last_bn_affine (bool): Whether to have learnable affine parameters + in the last BN layer (set False for SimSiam). Defaults to True. + with_last_bias (bool): Whether to use bias in the last fc layer. + Defaults to False. + with_avg_pool (bool): Whether to apply the global average pooling + after backbone. Defaults to True. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to dict(type='SyncBN'). + init_cfg (dict or list[dict], optional): Initialization config dict. 
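# Rough equivalent (illustrative only, BatchNorm1d standing in for the
# configured SyncBN) of what NonLinearNeck builds with num_layers=3 and
# with_last_bn_affine=False, i.e. the SimSiam-style projector
# fc-bn-relu-fc-bn-relu-fc-bn(no affine):
import torch
import torch.nn as nn

in_c, hid_c, out_c = 2048, 2048, 2048
projector = nn.Sequential(
    nn.Linear(in_c, hid_c, bias=False), nn.BatchNorm1d(hid_c),
    nn.ReLU(inplace=True),
    nn.Linear(hid_c, hid_c, bias=False), nn.BatchNorm1d(hid_c),
    nn.ReLU(inplace=True),
    nn.Linear(hid_c, out_c, bias=False),
    nn.BatchNorm1d(out_c, affine=False),
)
print(projector(torch.rand(4, in_c)).shape)   # torch.Size([4, 2048])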
+ """ + + def __init__( + self, + in_channels: int, + hid_channels: int, + out_channels: int, + num_layers: int = 2, + with_bias: bool = False, + with_last_bn: bool = True, + with_last_bn_affine: bool = True, + with_last_bias: bool = False, + with_avg_pool: bool = True, + norm_cfg: dict = dict(type='SyncBN'), + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + ) -> None: + super(NonLinearNeck, self).__init__(init_cfg) + self.with_avg_pool = with_avg_pool + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.relu = nn.ReLU(inplace=True) + self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias) + self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1] + + self.fc_names = [] + self.bn_names = [] + for i in range(1, num_layers): + this_channels = out_channels if i == num_layers - 1 \ + else hid_channels + if i != num_layers - 1: + self.add_module( + f'fc{i}', + nn.Linear(hid_channels, this_channels, bias=with_bias)) + self.add_module(f'bn{i}', + build_norm_layer(norm_cfg, this_channels)[1]) + self.bn_names.append(f'bn{i}') + else: + self.add_module( + f'fc{i}', + nn.Linear( + hid_channels, this_channels, bias=with_last_bias)) + if with_last_bn: + self.add_module( + f'bn{i}', + build_norm_layer( + dict(**norm_cfg, affine=with_last_bn_affine), + this_channels)[1]) + self.bn_names.append(f'bn{i}') + else: + self.bn_names.append(None) + self.fc_names.append(f'fc{i}') + + def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """Forward function. + + Args: + x (Tuple[torch.Tensor]): The feature map of backbone. + + Returns: + Tuple[torch.Tensor]: The output features. + """ + assert len(x) == 1 + x = x[0] + if self.with_avg_pool: + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc0(x) + x = self.bn0(x) + for fc_name, bn_name in zip(self.fc_names, self.bn_names): + fc = getattr(self, fc_name) + x = self.relu(x) + x = fc(x) + if bn_name is not None: + bn = getattr(self, bn_name) + x = bn(x) + return (x, ) diff --git a/mmpretrain/models/necks/simmim_neck.py b/mmpretrain/models/necks/simmim_neck.py new file mode 100644 index 0000000..cb1e29b --- /dev/null +++ b/mmpretrain/models/necks/simmim_neck.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SimMIMLinearDecoder(BaseModule): + """Linear Decoder For SimMIM pretraining. + + This neck reconstructs the original image from the shrunk feature map. + + Args: + in_channels (int): Channel dimension of the feature map. + encoder_stride (int): The total stride of the encoder. + """ + + def __init__(self, in_channels: int, encoder_stride: int) -> None: + super().__init__() + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=encoder_stride**2 * 3, + kernel_size=1), + nn.PixelShuffle(encoder_stride), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x = self.decoder(x) + return x diff --git a/mmpretrain/models/necks/spark_neck.py b/mmpretrain/models/necks/spark_neck.py new file mode 100644 index 0000000..ac129da --- /dev/null +++ b/mmpretrain/models/necks/spark_neck.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +def is_pow2n(x): + return x > 0 and (x & (x - 1) == 0) + + +class ConvBlock2x(BaseModule): + """The definition of convolution block.""" + + def __init__(self, + in_channels: int, + out_channels: int, + mid_channels: int, + norm_cfg: dict, + act_cfg: dict, + last_act: bool, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=False) + self.norm1 = build_norm_layer(norm_cfg, mid_channels) + self.activate1 = MODELS.build(act_cfg) + + self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False) + self.norm2 = build_norm_layer(norm_cfg, out_channels) + self.activate2 = MODELS.build(act_cfg) if last_act else nn.Identity() + + def forward(self, x: torch.Tensor): + out = self.conv1(x) + out = self.norm1(out) + out = self.activate1(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.activate2(out) + return out + + +class DecoderConvModule(BaseModule): + """The convolution module of decoder with upsampling.""" + + def __init__(self, + in_channels: int, + out_channels: int, + mid_channels: int, + kernel_size: int = 4, + scale_factor: int = 2, + num_conv_blocks: int = 1, + norm_cfg: dict = dict(type='SyncBN'), + act_cfg: dict = dict(type='ReLU6'), + last_act: bool = True, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + + assert (kernel_size - scale_factor >= 0) and\ + (kernel_size - scale_factor) % 2 == 0,\ + f'kernel_size should be greater than or equal to scale_factor '\ + f'and (kernel_size - scale_factor) should be even numbers, '\ + f'while the kernel size is {kernel_size} and scale_factor is '\ + f'{scale_factor}.' + + padding = (kernel_size - scale_factor) // 2 + self.upsample = nn.ConvTranspose2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=scale_factor, + padding=padding, + bias=True) + + conv_blocks_list = [ + ConvBlock2x( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + norm_cfg=norm_cfg, + last_act=last_act, + act_cfg=act_cfg) for _ in range(num_conv_blocks) + ] + self.conv_blocks = nn.Sequential(*conv_blocks_list) + + def forward(self, x): + x = self.upsample(x) + return self.conv_blocks(x) + + +@MODELS.register_module() +class SparKLightDecoder(BaseModule): + """The decoder for SparK, which upsamples the feature maps. + + Args: + feature_dim (int): The dimension of feature map. + upsample_ratio (int): The ratio of upsample, equal to downsample_raito + of the algorithm. + mid_channels (int): The middle channel of `DecoderConvModule`. Defaults + to 0. + kernel_size (int): The kernel size of `ConvTranspose2d` in + `DecoderConvModule`. Defaults to 4. + scale_factor (int): The scale_factor of `ConvTranspose2d` in + `DecoderConvModule`. Defaults to 2. + num_conv_blocks (int): The number of convolution blocks in + `DecoderConvModule`. Defaults to 1. + norm_cfg (dict): Normalization config. Defaults to dict(type='SyncBN'). + act_cfg (dict): Activation config. Defaults to dict(type='ReLU6'). + last_act (bool): Whether apply the last activation in + `DecoderConvModule`. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. 
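# Width schedule sketch (illustrative, not part of the patch): for
# feature_dim=512 and upsample_ratio=32 the decoder below stacks
# log2(32) = 5 upsampling modules and halves the channel width at each stage
# before the final 1x1 projection to RGB.
import math

feature_dim, upsample_ratio = 512, 32
assert upsample_ratio > 0 and (upsample_ratio & (upsample_ratio - 1)) == 0  # power of two
n = round(math.log2(upsample_ratio))
channels = [feature_dim // 2 ** i for i in range(n + 1)]
print(channels)   # [512, 256, 128, 64, 32, 16], consumed as (c_in, c_out) pairs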
+ """ + + def __init__( + self, + feature_dim: int, + upsample_ratio: int, + mid_channels: int = 0, + kernel_size: int = 4, + scale_factor: int = 2, + num_conv_blocks: int = 1, + norm_cfg: dict = dict(type='SyncBN'), + act_cfg: dict = dict(type='ReLU6'), + last_act: bool = False, + init_cfg: Optional[dict] = [ + dict(type='Kaiming', layer=['Conv2d', 'ConvTranspose2d']), + dict(type='TruncNormal', std=0.02, layer=['Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'LayerNorm', 'SyncBatchNorm']) + ], + ): + super().__init__(init_cfg=init_cfg) + self.feature_dim = feature_dim + + assert is_pow2n(upsample_ratio) + n = round(math.log2(upsample_ratio)) + channels = [feature_dim // 2**i for i in range(n + 1)] + + self.decoder = nn.ModuleList([ + DecoderConvModule( + in_channels=c_in, + out_channels=c_out, + mid_channels=c_in if mid_channels == 0 else mid_channels, + kernel_size=kernel_size, + scale_factor=scale_factor, + num_conv_blocks=num_conv_blocks, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + last_act=last_act) + for (c_in, c_out) in zip(channels[:-1], channels[1:]) + ]) + self.proj = nn.Conv2d( + channels[-1], 3, kernel_size=1, stride=1, bias=True) + + def forward(self, to_dec): + x = 0 + for i, d in enumerate(self.decoder): + if i < len(to_dec) and to_dec[i] is not None: + x = x + to_dec[i] + x = self.decoder[i](x) + return self.proj(x) diff --git a/mmpretrain/models/necks/swav_neck.py b/mmpretrain/models/necks/swav_neck.py new file mode 100644 index 0000000..807ae8b --- /dev/null +++ b/mmpretrain/models/necks/swav_neck.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SwAVNeck(BaseModule): + """The non-linear neck of SwAV: fc-bn-relu-fc-normalization. + + Args: + in_channels (int): Number of input channels. + hid_channels (int): Number of hidden channels. + out_channels (int): Number of output channels. + with_avg_pool (bool): Whether to apply the global average pooling after + backbone. Defaults to True. + with_l2norm (bool): whether to normalize the output after projection. + Defaults to True. + norm_cfg (dict): Dictionary to construct and config norm layer. + Defaults to dict(type='SyncBN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__( + self, + in_channels: int, + hid_channels: int, + out_channels: int, + with_avg_pool: bool = True, + with_l2norm: bool = True, + norm_cfg: dict = dict(type='SyncBN'), + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + ) -> None: + super().__init__(init_cfg) + self.with_avg_pool = with_avg_pool + self.with_l2norm = with_l2norm + if with_avg_pool: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + if out_channels == 0: + self.projection_neck = nn.Identity() + elif hid_channels == 0: + self.projection_neck = nn.Linear(in_channels, out_channels) + else: + self.norm = build_norm_layer(norm_cfg, hid_channels)[1] + self.projection_neck = nn.Sequential( + nn.Linear(in_channels, hid_channels), + self.norm, + nn.ReLU(inplace=True), + nn.Linear(hid_channels, out_channels), + ) + + def forward_projection(self, x: torch.Tensor) -> torch.Tensor: + """Compute projection. + + Args: + x (torch.Tensor): The feature vectors after pooling. 
+ + Returns: + torch.Tensor: The output features with projection or L2-norm. + """ + x = self.projection_neck(x) + if self.with_l2norm: + x = nn.functional.normalize(x, dim=1, p=2) + return x + + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + + Args: + x (List[torch.Tensor]): list of feature maps, len(x) according to + len(num_crops). + + Returns: + torch.Tensor: The projection vectors. + """ + avg_out = [] + for _x in x: + _x = _x[0] + if self.with_avg_pool: + _out = self.avgpool(_x) + avg_out.append(_out) + feat_vec = torch.cat(avg_out) # [sum(num_crops) * N, C] + feat_vec = feat_vec.view(feat_vec.size(0), -1) + output = self.forward_projection(feat_vec) + return output diff --git a/mmpretrain/models/peft/__init__.py b/mmpretrain/models/peft/__init__.py new file mode 100644 index 0000000..9f43e14 --- /dev/null +++ b/mmpretrain/models/peft/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .lora import LoRAModel + +__all__ = [ + 'LoRAModel', +] diff --git a/mmpretrain/models/peft/lora.py b/mmpretrain/models/peft/lora.py new file mode 100644 index 0000000..ae1bae7 --- /dev/null +++ b/mmpretrain/models/peft/lora.py @@ -0,0 +1,205 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import re +from typing import Any, List + +import torch +from mmengine.logging import print_log +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +class LoRALinear(nn.Module): + r"""Implements LoRA in a linear layer. + + Args: + original_layer (nn.Linear): The linear layer to be finetuned. + alpha (int): The scale factor of LoRA. Defaults to 1. + rank (int): The rank of LoRA. Defaults to 0. + drop_rate (float): The drop out rate for LoRA. Defaults to 0. + + Note: + The forward process of LoRA linear layer is: + + .. math:: + `y = W_0 x + BAx * (\alpha / r)` + + Where :math:`x` is the input, :math:`y` is the output, + :math:`W_0` is the parameter of the original layer, + :math:`A` and :math:`B` are the low-rank decomposition matrixs, + :math: `\alpha` is the scale factor and :math: `r` is the rank. + """ + + def __init__(self, + original_layer: nn.Linear, + alpha: int = 1, + rank: int = 0, + drop_rate: float = 0.): + super(LoRALinear, self).__init__() + in_features = original_layer.in_features + out_features = original_layer.out_features + + self.lora_dropout = nn.Dropout(drop_rate) + self.lora_down = nn.Linear(in_features, rank, bias=False) + self.lora_up = nn.Linear(rank, out_features, bias=False) + self.scaling = alpha / rank + + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_up.weight) + + self.original_layer = original_layer + + def forward(self, x: torch.Tensor): + out = self.original_layer(x) + + lora_x = self.lora_dropout(x) + lora_out = self.lora_up(self.lora_down(lora_x)) * self.scaling + + return out + lora_out + + +@MODELS.register_module() +class LoRAModel(BaseModule): + """Implements LoRA in a module. + + An PyTorch implement of : `LoRA: Low-Rank Adaptation + of Large Language Models `_ + + Args: + module (dict): The config of the module to be finetuned. See + :mod:`mmpretrain.models` + alpha (int): The scale factor of LoRA. Defaults to 1. + rank (int): The rank of LoRA. Defaults to 0. + drop_rate (float): The drop out rate for LoRA. Defaults to 0. + targets (List[dict]): The target layers to be applied with the LoRA. + Defaults to a empty list. Specify by regular expression or suffix. 
+ + Examples: + >>> model = LoRAModel( + ... module=dict(type='VisionTransformer', arch='b'), + ... alpha=4, + ... rank=4, + ... drop_rate=0.1, + ... targets=[ + ... dict(type='.*qkv'), # regular expression + ... dict(type='proj', alpha=8, rank=8), # suffix + ... ]) + """ + + def __init__(self, + module: dict, + alpha: int = 1, + rank: int = 0, + drop_rate: float = 0., + targets: List[dict] = list()): + + super().__init__() + + module = MODELS.build(module) + module.init_weights() + + self.module = module + self.alpha = alpha + self.rank = rank + self.drop_rate = drop_rate + + assert len(targets) != 0, \ + 'The length of target layers should not be 0.' + + self.targets = targets + + self.applied = False + self.apply_lora() + + if not self.applied: + raise ValueError( + 'No lora layer is replaced. Please check targets.') + + self._set_lora_trainable() + self._register_state_dict_hooks() + + def apply_lora(self): + """Apply LoRA to target layers.""" + module_names = [k for k, _ in self.module.named_modules()] + for module_name in module_names: + for target in self.targets: + target_name = target['type'] + target_alpha = target.get('alpha', self.alpha) + target_rank = target.get('rank', self.rank) + target_drop_rate = target.get('drop_rate', self.drop_rate) + + if re.fullmatch(target_name, module_name) or \ + module_name.endswith(target_name): + current_module = self.module.get_submodule(module_name) + if isinstance(current_module, nn.Linear): + print_log( + f'Set LoRA for {module_name} ' + f'with alpha: {target_alpha}, ' + f'rank: {target_rank}, ' + f'drop rate: {target_drop_rate}', + logger='current') + + self._replace_module(module_name, current_module, + target_alpha, target_rank, + target_drop_rate) + self.applied = True + + def _replace_module(self, module_name: str, current_module: nn.Module, + alpha: int, rank: int, drop_rate: float): + """Replace target layer with LoRA linear layer in place.""" + parent_module_name = '.'.join(module_name.split('.')[:-1]) + parent_module = self.module.get_submodule(parent_module_name) + + target_name = module_name.split('.')[-1] + target_module = LoRALinear(current_module, alpha, rank, drop_rate) + setattr(parent_module, target_name, target_module) + + def _set_lora_trainable(self): + """Set only the lora parameters trainable.""" + for name, param in self.named_parameters(): + if '.lora_' in name: + param.requires_grad = True + else: + param.requires_grad = False + + def _register_state_dict_hooks(self): + """Register state dict hooks. + + Register state dict saving hooks to save only the lora parameters to + the state dict. And register state dict loading hooks to handle the + incompatible keys while loading the state dict. 
+ """ + + def _state_dict_hook(module, state_dict, prefix, local_metadata): + """Save only the lora parameters to the state dict.""" + keys = [k for k, _ in state_dict.items()] + for key in keys: + if '.lora_' not in key: + state_dict.pop(key) + + self._register_state_dict_hook(_state_dict_hook) + + def _load_state_dict_post_hook(module, incompatible_keys): + """Handle the incompatible keys while loading the state dict.""" + missing_keys = incompatible_keys.missing_keys.copy() + for key in missing_keys: + if '.lora_' not in key: + incompatible_keys.missing_keys.remove(key) + + unexpected_keys = incompatible_keys.unexpected_keys.copy() + for key in unexpected_keys: + if '.lora_' not in key: + incompatible_keys.unexpected_keys.remove(key) + + self.register_load_state_dict_post_hook(_load_state_dict_post_hook) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + try: + return super(LoRAModel, self).__getattr__(name) + except AttributeError: + return self.module.__getattribute__(name) diff --git a/mmpretrain/models/retrievers/__init__.py b/mmpretrain/models/retrievers/__init__.py new file mode 100644 index 0000000..593b637 --- /dev/null +++ b/mmpretrain/models/retrievers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseRetriever +from .image2image import ImageToImageRetriever + +__all__ = ['BaseRetriever', 'ImageToImageRetriever'] diff --git a/mmpretrain/models/retrievers/base.py b/mmpretrain/models/retrievers/base.py new file mode 100644 index 0000000..1581679 --- /dev/null +++ b/mmpretrain/models/retrievers/base.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement +from torch.utils.data import DataLoader + + +class BaseRetriever(BaseModel, metaclass=ABCMeta): + """Base class for retriever. + + Args: + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None, it will use "BaseDataPreprocessor" as type, see + :class:`mmengine.model.BaseDataPreprocessor` for more details. + Defaults to None. + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + Attributes: + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An + extra data pre-processing module, which processes data from + dataloader to the format accepted by :meth:`forward`. 
+ """ + + def __init__( + self, + prototype: Union[DataLoader, dict, str, torch.Tensor] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + ): + super(BaseRetriever, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + self.prototype = prototype + self.prototype_inited = False + + @abstractmethod + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'loss'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor without any + post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor, tuple): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor. + - If ``mode="predict"``, return a list of + :obj:`mmpretrain.structures.DataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + pass + + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The sub-classes are recommended to implement this method to extract + features from backbone and neck. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + """ + raise NotImplementedError + + def loss(self, inputs: torch.Tensor, + data_samples: List[BaseDataElement]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + raise NotImplementedError + + def predict(self, + inputs: tuple, + data_samples: Optional[List[BaseDataElement]] = None, + **kwargs) -> List[BaseDataElement]: + """Predict results from the extracted features. + + Args: + inputs (tuple): The features extracted from the backbone. + data_samples (List[BaseDataElement], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + """ + raise NotImplementedError + + def matching(self, inputs: torch.Tensor): + """Compare the prototype and calculate the similarity. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C). + """ + raise NotImplementedError + + def prepare_prototype(self): + """Preprocessing the prototype before predict.""" + raise NotImplementedError + + def dump_prototype(self, path): + """Save the features extracted from the prototype to the specific path. + + Args: + path (str): Path to save feature. 
+ """ + raise NotImplementedError diff --git a/mmpretrain/models/retrievers/image2image.py b/mmpretrain/models/retrievers/image2image.py new file mode 100644 index 0000000..a00c1dc --- /dev/null +++ b/mmpretrain/models/retrievers/image2image.py @@ -0,0 +1,314 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, List, Optional, Union + +import mmengine.dist as dist +import torch +import torch.nn as nn +from mmengine.runner import Runner +from torch.utils.data import DataLoader + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .base import BaseRetriever + + +@MODELS.register_module() +class ImageToImageRetriever(BaseRetriever): + """Image To Image Retriever for supervised retrieval task. + + Args: + image_encoder (Union[dict, List[dict]]): Encoder for extracting + features. + prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be + retrieved. The following four types are supported. + + - DataLoader: The original dataloader serves as the prototype. + - dict: The configuration to construct Dataloader. + - str: The path of the saved vector. + - torch.Tensor: The saved tensor whose dimension should be dim. + + head (dict, optional): The head module to calculate loss from + processed features. See :mod:`mmpretrain.models.heads`. Notice + that if the head is not set, `loss` method cannot be used. + Defaults to None. + similarity_fn (Union[str, Callable]): The way that the similarity + is calculated. If `similarity` is callable, it is used directly + as the measure function. If it is a string, the appropriate + method will be used. The larger the calculated value, the + greater the similarity. Defaults to "cosine_similarity". + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in + :mod:`mmpretrain.model.utils.augment`. + + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + topk (int): Return the topk of the retrieval result. `-1` means + return all. Defaults to -1. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + image_encoder: Union[dict, List[dict]], + prototype: Union[DataLoader, dict, str, torch.Tensor], + head: Optional[dict] = None, + pretrained: Optional[str] = None, + similarity_fn: Union[str, Callable] = 'cosine_similarity', + train_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None, + topk: int = -1, + init_cfg: Optional[dict] = None): + + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. 
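+        # e.g. an empty dict becomes
+        # ``dict(type='mmpretrain.ClsDataPreprocessor')`` after the call
+        # below, so the scoped class is resolved even when this model is
+        # built from another registry scope.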
+ data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super(ImageToImageRetriever, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(image_encoder, nn.Module): + image_encoder = MODELS.build(image_encoder) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + + self.image_encoder = image_encoder + self.head = head + + self.similarity = similarity_fn + + assert isinstance(prototype, (str, torch.Tensor, dict, DataLoader)), ( + 'The `prototype` in `ImageToImageRetriever` must be a path, ' + 'a torch.Tensor, a dataloader or a dataloader dict format config.') + self.prototype = prototype + self.prototype_inited = False + self.topk = topk + + @property + def similarity_fn(self): + """Returns a function that calculates the similarity.""" + # If self.similarity_way is callable, return it directly + if isinstance(self.similarity, Callable): + return self.similarity + + if self.similarity == 'cosine_similarity': + # a is a tensor with shape (N, C) + # b is a tensor with shape (M, C) + # "cosine_similarity" will get the matrix of similarity + # with shape (N, M). + # The higher the score is, the more similar is + return lambda a, b: torch.cosine_similarity( + a.unsqueeze(1), b.unsqueeze(0), dim=-1) + else: + raise RuntimeError(f'Invalid function "{self.similarity_fn}".') + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor without any + post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor, tuple): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor. + - If ``mode="predict"``, return a list of + :obj:`mmpretrain.structures.DataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + return self.extract_feat(inputs) + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs): + """Extract features from the input tensor with shape (N, C, ...). + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + Returns: + Tensor: The output of encoder. + """ + + feat = self.image_encoder(inputs) + return feat + + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. 
+ + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + feats = self.extract_feat(inputs) + return self.head.loss(feats, data_samples) + + def matching(self, inputs: torch.Tensor): + """Compare the prototype and calculate the similarity. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C). + Returns: + dict: a dictionary of score and prediction label based on fn. + """ + sim = self.similarity_fn(inputs, self.prototype_vecs) + sorted_sim, indices = torch.sort(sim, descending=True, dim=-1) + predictions = dict( + score=sim, pred_label=indices, pred_score=sorted_sim) + return predictions + + def predict(self, + inputs: tuple, + data_samples: Optional[List[DataSample]] = None, + **kwargs) -> List[DataSample]: + """Predict results from the extracted features. + + Args: + inputs (tuple): The features extracted from the backbone. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + Returns: + List[DataSample]: the raw data_samples with + the predicted results + """ + if not self.prototype_inited: + self.prepare_prototype() + + feats = self.extract_feat(inputs) + if isinstance(feats, tuple): + feats = feats[-1] + + # Matching of similarity + result = self.matching(feats) + return self._get_predictions(result, data_samples) + + def _get_predictions(self, result, data_samples): + """Post-process the output of retriever.""" + pred_scores = result['score'] + pred_labels = result['pred_label'] + if self.topk != -1: + topk = min(self.topk, pred_scores.size()[-1]) + pred_labels = pred_labels[:, :topk] + + if data_samples is not None: + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + data_sample.set_pred_score(score).set_pred_label(label) + else: + data_samples = [] + for score, label in zip(pred_scores, pred_labels): + data_samples.append( + DataSample().set_pred_score(score).set_pred_label(label)) + return data_samples + + def _get_prototype_vecs_from_dataloader(self, data_loader): + """get prototype_vecs from dataloader.""" + self.eval() + num = len(data_loader.dataset) + + prototype_vecs = None + for data_batch in track_on_main_process(data_loader, + 'Prepare prototype'): + data = self.data_preprocessor(data_batch, False) + feat = self(**data) + if isinstance(feat, tuple): + feat = feat[-1] + + if prototype_vecs is None: + dim = feat.shape[-1] + prototype_vecs = torch.zeros(num, dim) + for i, data_sample in enumerate(data_batch['data_samples']): + sample_idx = data_sample.get('sample_idx') + prototype_vecs[sample_idx] = feat[i] + + assert prototype_vecs is not None + dist.all_reduce(prototype_vecs) + return prototype_vecs + + def _get_prototype_vecs_from_path(self, proto_path): + """get prototype_vecs from prototype path.""" + data = [None] + if dist.is_main_process(): + data[0] = torch.load(proto_path) + dist.broadcast_object_list(data, src=0) + prototype_vecs = data[0] + assert prototype_vecs is not None + return prototype_vecs + + @torch.no_grad() + def prepare_prototype(self): + """Used in meta testing. This function will be called before the meta + testing. Obtain the vector based on the prototype. 
+ + - torch.Tensor: The prototype vector is the prototype + - str: The path of the extracted feature path, parse data structure, + and generate the prototype feature vector set + - Dataloader or config: Extract and save the feature vectors according + to the dataloader + """ + device = next(self.image_encoder.parameters()).device + if isinstance(self.prototype, torch.Tensor): + prototype_vecs = self.prototype + elif isinstance(self.prototype, str): + prototype_vecs = self._get_prototype_vecs_from_path(self.prototype) + elif isinstance(self.prototype, (dict, DataLoader)): + loader = Runner.build_dataloader(self.prototype) + prototype_vecs = self._get_prototype_vecs_from_dataloader(loader) + + self.register_buffer( + 'prototype_vecs', prototype_vecs.to(device), persistent=False) + self.prototype_inited = True + + def dump_prototype(self, path): + """Save the features extracted from the prototype to specific path. + + Args: + path (str): Path to save feature. + """ + if not self.prototype_inited: + self.prepare_prototype() + torch.save(self.prototype_vecs, path) diff --git a/mmpretrain/models/selfsup/__init__.py b/mmpretrain/models/selfsup/__init__.py new file mode 100644 index 0000000..08c1ed5 --- /dev/null +++ b/mmpretrain/models/selfsup/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .barlowtwins import BarlowTwins +from .base import BaseSelfSupervisor +from .beit import VQKD, BEiT, BEiTPretrainViT +from .byol import BYOL +from .cae import CAE, CAEPretrainViT, DALLEEncoder +from .densecl import DenseCL +from .eva import EVA +from .itpn import iTPN, iTPNHiViT +from .mae import MAE, MAEHiViT, MAEViT +from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT +from .mff import MFF, MFFViT +from .milan import MILAN, CLIPGenerator, MILANViT +from .mixmim import MixMIM, MixMIMPretrainTransformer +from .moco import MoCo +from .mocov3 import MoCoV3, MoCoV3ViT +from .simclr import SimCLR +from .simmim import SimMIM, SimMIMSwinTransformer +from .simsiam import SimSiam +from .spark import SparK +from .swav import SwAV + +__all__ = [ + 'BaseSelfSupervisor', + 'BEiTPretrainViT', + 'VQKD', + 'CAEPretrainViT', + 'DALLEEncoder', + 'MAEViT', + 'MAEHiViT', + 'iTPNHiViT', + 'iTPN', + 'HOGGenerator', + 'MaskFeatViT', + 'CLIPGenerator', + 'MILANViT', + 'MixMIMPretrainTransformer', + 'MoCoV3ViT', + 'SimMIMSwinTransformer', + 'MoCo', + 'MoCoV3', + 'BYOL', + 'SimCLR', + 'SimSiam', + 'BEiT', + 'CAE', + 'MAE', + 'MaskFeat', + 'MILAN', + 'MixMIM', + 'SimMIM', + 'EVA', + 'DenseCL', + 'BarlowTwins', + 'SwAV', + 'SparK', + 'MFF', + 'MFFViT', +] diff --git a/mmpretrain/models/selfsup/barlowtwins.py b/mmpretrain/models/selfsup/barlowtwins.py new file mode 100644 index 0000000..4c75cd0 --- /dev/null +++ b/mmpretrain/models/selfsup/barlowtwins.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class BarlowTwins(BaseSelfSupervisor): + """BarlowTwins. + + Implementation of `Barlow Twins: Self-Supervised Learning via Redundancy + Reduction `_. + Part of the code is borrowed from: + ``_. + """ + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. 
+ data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + img_v1 = inputs[0] + img_v2 = inputs[1] + + z1 = self.neck(self.backbone(img_v1))[0] # NxC + z2 = self.neck(self.backbone(img_v2))[0] # NxC + + loss = self.head.loss(z1, z2) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/base.py b/mmpretrain/models/selfsup/base.py new file mode 100644 index 0000000..9d53a72 --- /dev/null +++ b/mmpretrain/models/selfsup/base.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta): + """BaseModel for Self-Supervised Learning. + + All self-supervised algorithms should inherit this module. + + Args: + backbone (dict): The backbone module. See + :mod:`mmpretrain.models.backbones`. + neck (dict, optional): The neck module to process features from + backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. + head (dict, optional): The head module to do prediction and calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, almost all methods cannot be + used except :meth:`extract_feat`. Defaults to None. + target_generator: (dict, optional): The target_generator module to + generate targets for self-supervised learning optimization, such as + HOG, extracted features from other modules(DALL-E, CLIP), etc. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. 
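+
+    A minimal illustrative config for one subclass (the neck and head type
+    names are assumptions following the usual mmpretrain layout; exact
+    fields depend on the algorithm):
+
+    Examples:
+        >>> model_cfg = dict(
+        ...     type='MAE',
+        ...     backbone=dict(type='MAEViT', arch='b', mask_ratio=0.75),
+        ...     neck=dict(type='MAEPretrainDecoder'),
+        ...     head=dict(type='MAEPretrainHead'))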
+ """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + target_generator: Optional[dict] = None, + pretrained: Optional[str] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + if pretrained is not None: + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + data_preprocessor = data_preprocessor or {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'SelfSupDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + elif not isinstance(data_preprocessor, nn.Module): + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(backbone, nn.Module): + backbone = MODELS.build(backbone) + if neck is not None and not isinstance(neck, nn.Module): + neck = MODELS.build(neck) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + if target_generator is not None and not isinstance( + target_generator, nn.Module): + target_generator = MODELS.build(target_generator) + + self.backbone = backbone + self.neck = neck + self.head = head + self.target_generator = target_generator + + @property + def with_neck(self) -> bool: + """Check if the model has a neck module.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_head(self) -> bool: + """Check if the model has a head module.""" + return hasattr(self, 'head') and self.head is not None + + @property + def with_target_generator(self) -> bool: + """Check if the model has a target_generator module.""" + return hasattr( + self, 'target_generator') and self.target_generator is not None + + def forward(self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method currently accepts two modes: "tensor" and "loss": + + - "tensor": Forward the backbone network and return the feature + tensor(s) tensor without any post-processing, same as a common + PyTorch Module. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Args: + inputs (torch.Tensor or List[torch.Tensor]): The input tensor with + shape (N, C, ...) in general. + data_samples (List[DataSample], optional): The other data of + every samples. It's required for some algorithms + if ``mode="loss"``. Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + feats = self.extract_feat(inputs) + return feats + elif mode == 'loss': + return self.loss(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The default behavior is extracting features from backbone. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + + Returns: + tuple | Tensor: The output feature tensor(s). 
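+
+        A quick usage sketch for a built self-supervisor ``model`` (shapes
+        are illustrative):
+
+        Examples:
+            >>> imgs = torch.randn(2, 3, 224, 224)
+            >>> feats = model.extract_feat(imgs)
+            >>> # most backbones return a tuple of feature tensors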
+ """ + x = self.backbone(inputs) + return x + + @abstractmethod + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + This is a abstract method, and subclass should overwrite this methods + if needed. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + raise NotImplementedError + + def get_layer_depth(self, param_name: str): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + + Returns: + Tuple[int, int]: The layer-wise depth and the max depth. + """ + if hasattr(self.backbone, 'get_layer_depth'): + return self.backbone.get_layer_depth(param_name, 'backbone.') + else: + raise NotImplementedError( + f"The backbone {type(self.backbone)} doesn't " + 'support `get_layer_depth` by now.') diff --git a/mmpretrain/models/selfsup/beit.py b/mmpretrain/models/selfsup/beit.py new file mode 100644 index 0000000..c301f7d --- /dev/null +++ b/mmpretrain/models/selfsup/beit.py @@ -0,0 +1,357 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +from einops import rearrange +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ +from torch import nn + +from mmpretrain.models.backbones import BEiTViT +from mmpretrain.models.utils import NormEMAVectorQuantizer, resize_pos_embed +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class VQKD(BaseModule): + """Vector-Quantized Knowledge Distillation. + + The module only contains encoder and VectorQuantizer part + Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py + + Args: + encoder_config (dict): The config of encoder. + decoder_config (dict, optional): The config of decoder. Currently, + VQKD only support to build encoder. Defaults to None. + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + decay (float): The decay parameter of EMA. Defaults to 0.99. + beta (float): The mutiplier for VectorQuantizer loss. Defaults to 1. + quantize_kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. 
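+
+    A minimal usage sketch (the encoder config is an assumption; in
+    practice the tokenizer is loaded from a pre-trained checkpoint):
+
+    Examples:
+        >>> vqkd = VQKD(encoder_config=dict(
+        ...     arch='base', img_size=224, patch_size=16, out_type='featmap'))
+        >>> tokens = vqkd(torch.rand(1, 3, 224, 224))
+        >>> tokens.shape  # one token id per 16x16 patch
+        torch.Size([1, 196])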
+ """ # noqa: E501 + + def __init__(self, + encoder_config: dict, + decoder_config: Optional[dict] = None, + num_embed: int = 8192, + embed_dims: int = 32, + decay: float = 0.99, + beta: float = 1.0, + quantize_kmeans_init: bool = True, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.encoder = BEiTViT(**encoder_config) + if decoder_config is not None: + self.decoder = BEiTViT(**decoder_config) + + self.quantize = NormEMAVectorQuantizer( + num_embed=num_embed, + embed_dims=embed_dims, + beta=beta, + decay=decay, + kmeans_init=quantize_kmeans_init, + ) + + # task layer + self.encode_task_layer = nn.Sequential( + nn.Linear(self.encoder.arch_settings['embed_dims'], + self.encoder.arch_settings['embed_dims']), nn.Tanh(), + nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims)) + + def get_tokens(self, x: torch.Tensor) -> dict: + """Get tokens for beit pre-training.""" + _, embed_ind, _ = self.encode(x) + output = {} + output['token'] = embed_ind.view(x.shape[0], -1) + output['input_img'] = x + + return output + + def encode( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode the input images and get corresponding results.""" + encoder_features = self.encoder(x)[0] + B, C, N1, N2 = encoder_features.shape + encoder_features = encoder_features.permute(0, 2, 3, + 1).reshape(B, N1 * N2, C) + + with torch.cuda.amp.autocast(enabled=False): + to_quantizer_features = self.encode_task_layer( + encoder_features.type_as(self.encode_task_layer[-1].weight)) + + N = to_quantizer_features.shape[1] + h, w = int(math.sqrt(N)), int(math.sqrt(N)) + + to_quantizer_features = rearrange( + to_quantizer_features, 'b (h w) c -> b c h w', h=h, + w=w) # reshape for quantizer + quantize, loss, embed_ind = self.quantize(to_quantizer_features) + + return quantize, embed_ind, loss + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The forward function. + + Currently, only support to get tokens. + """ + return self.get_tokens(x)['token'] + + +@MODELS.register_module() +class BEiTPretrainViT(BEiTViT): + """Vision Transformer for BEiT pre-training. + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. 
Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + use_abs_pos_emb (bool): Whether or not use absolute position embedding. + Defaults to False. + use_rel_pos_bias (bool): Whether or not use relative position bias. + Defaults to False. + use_shared_rel_pos_bias (bool): Whether or not use shared relative + position bias. Defaults to True. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. Defaults to 0.1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: str = 'base', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + frozen_stages: int = -1, + use_abs_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + layer_scale_init_value: int = 0.1, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(padding=0), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + frozen_stages=frozen_stages, + use_abs_pos_emb=use_abs_pos_emb, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + use_rel_pos_bias=use_rel_pos_bias, + layer_scale_init_value=layer_scale_init_value, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. 
+ return + + trunc_normal_(self.cls_token, std=0.02) + trunc_normal_(self.mask_token, std=0.02) + self.rescale_init_weight() + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor]: + """The BEiT style forward function. + + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images, which is of shape (B x C x H x W). + mask (torch.Tensor, optional): Mask for input, which is of shape + (B x patch_resolution[0] x patch_resolution[1]). + + Returns: + Tuple[torch.Tensor]: Hidden features. + """ + if mask is None: + return super().forward(x) + + else: + x, patch_resolution = self.patch_embed(x) + + # replace the masked visual tokens by mask_token + B, L, _ = x.shape + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + self.shared_rel_pos_bias = self.rel_pos_bias().to( + mask.device) if self.rel_pos_bias is not None else None + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, rel_pos_bias=self.shared_rel_pos_bias) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + +@MODELS.register_module() +class BEiT(BaseSelfSupervisor): + """BEiT v1/v2. + + Implementation of `BEiT: BERT Pre-Training of Image Transformers + `_ and `BEiT v2: Masked Image Modeling + with Vector-Quantized Visual Tokenizers + `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
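+
+        A sketch of the expected inputs for a 224x224 image with 16x16
+        patches (shapes are illustrative):
+
+        Examples:
+            >>> # inputs[0]: view fed to the masked backbone, (N, 3, 224, 224)
+            >>> # inputs[1]: view fed to the target generator / tokenizer
+            >>> # data_samples[i].mask: (14, 14) patch mask, 1 = masked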
+ """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + + img_latent = self.backbone(inputs[0], mask) + + # inputs[1] is the target image + with torch.no_grad(): + target = self.target_generator(inputs[1]) + target = target.detach() + + if self.with_neck: + # BEiT v2 + feats, feats_cls_pt = self.neck( + img_latent, rel_pos_bias=self.backbone.shared_rel_pos_bias) + loss = self.head.loss(feats, feats_cls_pt, target, mask) + else: + # BEiT v1 + loss = self.head.loss(img_latent[0], target, mask) + + if isinstance(loss, torch.Tensor): + losses = dict(loss=loss) + return losses + elif isinstance(loss, Tuple): + # the loss_1 and loss_2 are general reconstruction loss (patch + # feature vectors from last layer of backbone) and early state + # reconstruction loss (patch feature vectors from intermediate + # layer of backbone) + loss_1, loss_2 = loss[0], loss[1] + losses = dict() + # the key with prefix 'loss', like loss_1 and loss_2, will be used + # as the final criterion + losses['loss_1'] = loss_1 + losses['loss_2'] = loss_2 + return losses diff --git a/mmpretrain/models/selfsup/byol.py b/mmpretrain/models/selfsup/byol.py new file mode 100644 index 0000000..803e400 --- /dev/null +++ b/mmpretrain/models/selfsup/byol.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import CosineEMA +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class BYOL(BaseSelfSupervisor): + """BYOL. + + Implementation of `Bootstrap Your Own Latent: A New Approach to + Self-Supervised Learning `_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features + to compact feature vectors. + head (dict): Config dict for module of head functions. + base_momentum (float): The base momentum coefficient for the target + network. Defaults to 0.004. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + base_momentum: float = 0.004, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.target_net = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
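+
+        The two augmented views are used symmetrically; in pseudo-code:
+
+        Examples:
+            >>> # loss = 2 * (head.loss(online(v1), target(v2))
+            >>> #            + head.loss(online(v2), target(v1)))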
+ """ + assert isinstance(inputs, list) + img_v1 = inputs[0] + img_v2 = inputs[1] + # compute online features + proj_online_v1 = self.neck(self.backbone(img_v1))[0] + proj_online_v2 = self.neck(self.backbone(img_v2))[0] + # compute target features + with torch.no_grad(): + # update the target net + self.target_net.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + proj_target_v1 = self.target_net(img_v1)[0] + proj_target_v2 = self.target_net(img_v2)[0] + + loss_1 = self.head.loss(proj_online_v1, proj_target_v2) + loss_2 = self.head.loss(proj_online_v2, proj_target_v1) + + losses = dict(loss=2. * (loss_1 + loss_2)) + return losses diff --git a/mmpretrain/models/selfsup/cae.py b/mmpretrain/models/selfsup/cae.py new file mode 100644 index 0000000..67ac091 --- /dev/null +++ b/mmpretrain/models/selfsup/cae.py @@ -0,0 +1,472 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Part of code is modified from BEiT +# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py +import math +from collections import OrderedDict +from functools import partial +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones import BEiTViT +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +class Conv2d(nn.Module): + """Rewrite Conv2d module according to DALL-E code.""" + + def __init__(self, + n_in: int, + n_out: int, + kw: int, + use_float16: bool = True, + device: torch.device = torch.device('cpu'), + requires_grad: bool = False) -> None: + super().__init__() + + w = torch.empty((n_out, n_in, kw, kw), + dtype=torch.float32, + device=device, + requires_grad=requires_grad) + w.normal_(std=1 / math.sqrt(n_in * kw**2)) + + b = torch.zeros((n_out, ), + dtype=torch.float32, + device=device, + requires_grad=requires_grad) + self.kw = kw + self.w, self.b = nn.Parameter(w), nn.Parameter(b) + self.use_float16 = use_float16 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_float16 and 'cuda' in self.w.device.type: + if x.dtype != torch.float16: + x = x.half() + + w, b = self.w.half(), self.b.half() + else: + if x.dtype != torch.float32: + x = x.float() + + w, b = self.w, self.b + + return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) + + +class EncoderBlock(nn.Module): + """Rewrite EncoderBlock module according to DALL-E code.""" + + def __init__(self, + n_in: int, + n_out: int, + n_layers: int, + device: torch.device = None, + requires_grad: bool = False) -> None: + super().__init__() + self.n_hid = n_out // 4 + self.post_gain = 1 / (n_layers**2) + + make_conv = partial(Conv2d, device=device, requires_grad=requires_grad) + self.id_path = make_conv(n_in, n_out, + 1) if n_in != n_out else nn.Identity() + self.res_path = nn.Sequential( + OrderedDict([ + ('relu_1', nn.ReLU()), + ('conv_1', make_conv(n_in, self.n_hid, 3)), + ('relu_2', nn.ReLU()), + ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_3', nn.ReLU()), + ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_4', nn.ReLU()), + ('conv_4', make_conv(self.n_hid, n_out, 1)), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.id_path(x) + self.post_gain * self.res_path(x) + + +@MODELS.register_module(name='DALL-E') +class DALLEEncoder(BaseModule): + """DALL-E 
Encoder for feature extraction. + + Args: + group_count (int): Number of groups in DALL-E encoder. Defaults to 4. + n_hid (int): Dimension of hidden layers. Defaults to 256. + n_blk_per_group (int): Number of blocks per group. Defaults to 2. + input_channels: (int): The channels of input images. Defaults to 3. + vocab_size (int): Vocabulary size, indicating the number of classes. + Defaults to 8192. + device (torch.device): Device of parameters. Defaults to + ``torch.device('cpu')``. + requires_grad (bool): Require gradient or not. Defaults to False. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + group_count: int = 4, + n_hid: int = 256, + n_blk_per_group: int = 2, + input_channels: int = 3, + vocab_size: int = 8192, + device: torch.device = torch.device('cpu'), + requires_grad: bool = False, + init_cfg: Union[dict, List[dict], None] = None): + super().__init__(init_cfg=init_cfg) + self.input_channels = input_channels + + blk_range = range(n_blk_per_group) + n_layers = group_count * n_blk_per_group + make_conv = partial(Conv2d, device=device, requires_grad=requires_grad) + make_blk = partial( + EncoderBlock, + n_layers=n_layers, + device=device, + requires_grad=requires_grad) + + self.blocks = nn.Sequential( + OrderedDict([ + ('input', make_conv(input_channels, 1 * n_hid, 7)), + ('group_1', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', make_blk(1 * n_hid, 1 * n_hid)) + for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_2', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(1 * n_hid if i == 0 else 2 * n_hid, + 2 * n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_3', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(2 * n_hid if i == 0 else 4 * n_hid, + 4 * n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_4', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(4 * n_hid if i == 0 else 8 * n_hid, + 8 * n_hid)) for i in blk_range], + ]))), + ('output', + nn.Sequential( + OrderedDict([ + ('relu', nn.ReLU()), + ('conv', + make_conv( + 8 * n_hid, vocab_size, 1, use_float16=False)), + ]))), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function of DALL-E encoder. + + Args: + x (torch.Tensor): The input images with shape (B, C, H, W). + + Returns: + torch.Tensor: The output with shape (B, vocab_size, h, w). + """ + x = x.float() + if len(x.shape) != 4: + raise ValueError(f'input shape {x.shape} is not 4d') + if x.shape[1] != self.input_channels: + raise ValueError(f'input has {x.shape[1]} channels but model \ + built for {self.input_channels}') + if x.dtype != torch.float32: + raise ValueError('input must have dtype torch.float32') + + return self.blocks(x) + + +@MODELS.register_module() +class CAEPretrainViT(BEiTViT): + """Vision Transformer for CAE pre-training and the implementation is based + on BEiTViT. + + Args: + arch (str | dict): Vision Transformer architecture. Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + bias (bool | str): The option to add leanable bias for q, k, v. If bias + is True, it will add leanable bias. 
If bias is 'qv_bias', it will + only add leanable bias for q, v. If bias is False, it will not add + bias for q, k, v. Default to 'qv_bias'. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + layer_scale_init_value (float, optional): The init value of gamma in + BEiTTransformerEncoderLayer. + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + arch: str = 'b', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + bias: bool = 'qv_bias', + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + frozen_stages: int = -1, + use_abs_pos_emb: bool = True, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = False, + layer_scale_init_value: float = None, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: dict = [ + dict(type='Constant', val=1, layer=['LayerNorm']), + dict(type='TruncNormal', std=0.02, layer=['Conv2d']), + dict(type='Xavier', distribution='uniform', layer=['Linear']) + ] + ) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + bias=bias, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + frozen_stages=frozen_stages, + use_abs_pos_emb=use_abs_pos_emb, + use_rel_pos_bias=use_rel_pos_bias, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + layer_scale_init_value=layer_scale_init_value, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + self.pos_embed.requires_grad = False + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # initialize position embedding in backbone + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=True) + self.pos_embed.data.copy_(pos_embed.float()) + + trunc_normal_(self.cls_token, std=.02) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> torch.Tensor: + """Generate features for masked images. + + This function generates mask images and get the hidden features for + visible patches. 
+ + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (torch.Tensor, optional): Mask for input, which is of shape + B x L. + + Returns: + torch.Tensor: hidden features. + """ + if mask is None: + return super().forward(x) + + else: + x, _ = self.patch_embed(x) + batch_size, _, dim = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + + # NOTE: unmasked embeddings + x_unmasked = x[~mask].reshape(batch_size, -1, dim) + x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1) + + pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1, + dim) + pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape( + batch_size, -1, dim) + pos_embed_unmasked = torch.cat( + (pos_embed[:, :1], pos_embed_unmasked), dim=1) + x_unmasked = x_unmasked + pos_embed_unmasked + + x_unmasked = self.drop_after_pos(x_unmasked) + + for i, layer in enumerate(self.layers): + x_unmasked = layer(x=x_unmasked, rel_pos_bias=None) + + if i == len(self.layers) - 1 and self.final_norm: + x_unmasked = self.norm1(x_unmasked) + + return x_unmasked + + +@MODELS.register_module() +class CAE(BaseSelfSupervisor): + """CAE. + + Implementation of `Context Autoencoder for Self-Supervised Representation + Learning `_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of neck. + head (dict): Config dict for module of head functions. + target_generator: (dict, optional): The target_generator module to + generate targets for self-supervised learning optimization, such as + HOG, extracted features from other modules(DALL-E, CLIP), etc. + base_momentum (float): The base momentum coefficient for the target + network. Defaults to 0.0. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + target_generator: Optional[dict] = None, + base_momentum: float = 0.0, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + target_generator=target_generator, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + self.momentum = base_momentum + self.teacher = MODELS.build(backbone) + + def init_weights(self) -> None: + """Initialize weights.""" + super().init_weights() + + # init the weights of teacher with those of backbone + for param_backbone, param_teacher in zip(self.backbone.parameters(), + self.teacher.parameters()): + param_teacher.detach() + param_teacher.data.copy_(param_backbone.data) + param_teacher.requires_grad = False + + def momentum_update(self) -> None: + """Momentum update of the teacher network.""" + for param_bacbone, param_teacher in zip(self.backbone.parameters(), + self.teacher.parameters()): + param_teacher.data = param_teacher.data * self.momentum + \ + param_bacbone.data * (1. 
- self.momentum) + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + mask = mask.flatten(1).to(torch.bool) + + unmasked = self.backbone(inputs[0], mask) + + # get the latent prediction for the masked patches + with torch.no_grad(): + # inputs[0] is the prediction image + latent_target = self.teacher(inputs[0], ~mask) + latent_target = latent_target[:, 1:, :] + self.momentum_update() + + pos_embed = self.backbone.pos_embed.expand(inputs[0].shape[0], -1, -1) + pos_embed_masked = pos_embed[:, + 1:][mask].reshape(inputs[0].shape[0], -1, + pos_embed.shape[-1]) + pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape( + inputs[0].shape[0], -1, pos_embed.shape[-1]) + + # input the unmasked tokens and masked tokens to the decoder + logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked, + pos_embed_unmasked) + + logits = logits.view(-1, logits.shape[-1]) + # inputs[1] is the target image + logits_target = self.target_generator(inputs[1]) + loss_main, loss_align = self.head.loss(logits, logits_target, + latent_pred, latent_target, + mask) + losses = dict() + + losses['loss'] = loss_main + loss_align + losses['main'] = loss_main + losses['align'] = loss_align + return losses diff --git a/mmpretrain/models/selfsup/densecl.py b/mmpretrain/models/selfsup/densecl.py new file mode 100644 index 0000000..c969af1 --- /dev/null +++ b/mmpretrain/models/selfsup/densecl.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_gather +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class DenseCL(BaseSelfSupervisor): + """DenseCL. + + Implementation of `Dense Contrastive Learning for Self-Supervised Visual + Pre-Training `_. + Borrowed from the authors' code: ``_. + The loss_lambda warmup is in `engine/hooks/densecl_hook.py`. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features to compact + feature vectors. + head (dict): Config dict for module of head functions. + queue_len (int): Number of negative keys maintained in the queue. + Defaults to 65536. + feat_dim (int): Dimension of compact feature vectors. Defaults to 128. + momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.999. + loss_lambda (float): Loss weight for the single and dense contrastive + loss. Defaults to 0.5. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. 
+ init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + queue_len: int = 65536, + feat_dim: int = 128, + momentum: float = 0.001, + loss_lambda: float = 0.5, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.encoder_k = ExponentialMovingAverage( + nn.Sequential(self.backbone, self.neck), momentum) + + self.queue_len = queue_len + self.loss_lambda = loss_lambda + + # create the queue + self.register_buffer('queue', torch.randn(feat_dim, queue_len)) + self.queue = nn.functional.normalize(self.queue, dim=0) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + # create the second queue for dense output + self.register_buffer('queue2', torch.randn(feat_dim, queue_len)) + self.queue2 = nn.functional.normalize(self.queue2, dim=0) + self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None: + """Update queue.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue_ptr[0] = ptr + + @torch.no_grad() + def _dequeue_and_enqueue2(self, keys: torch.Tensor) -> None: + """Update queue2.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue2_ptr) + assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue2[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue2_ptr[0] = ptr + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
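+
+        Shapes of the main logits (illustrative), with ``S^2`` spatial
+        positions per image and a negative queue of length ``K``:
+
+        Examples:
+            >>> # global branch: l_pos (N, 1), l_neg (N, K)
+            >>> # dense branch: l_pos_dense (N*S^2, 1), l_neg_dense (N*S^2, K)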
+ """ + assert isinstance(inputs, list) + im_q = inputs[0] + im_k = inputs[1] + # compute query features + q_b = self.backbone(im_q) # backbone features + q, q_grid, q2 = self.neck(q_b) # queries: NxC; NxCxS^2 + q_b = q_b[0] + q_b = q_b.view(q_b.size(0), q_b.size(1), -1) + + q = nn.functional.normalize(q, dim=1) + q2 = nn.functional.normalize(q2, dim=1) + q_grid = nn.functional.normalize(q_grid, dim=1) + q_b = nn.functional.normalize(q_b, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + # update the key encoder + self.encoder_k.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + # shuffle for making use of BN + im_k, idx_unshuffle = batch_shuffle_ddp(im_k) + + k_b = self.encoder_k.module[0](im_k) # backbone features + k, k_grid, k2 = self.encoder_k.module[1](k_b) # keys: NxC; NxCxS^2 + k_b = k_b[0] + k_b = k_b.view(k_b.size(0), k_b.size(1), -1) + + k = nn.functional.normalize(k, dim=1) + k2 = nn.functional.normalize(k2, dim=1) + k_grid = nn.functional.normalize(k_grid, dim=1) + k_b = nn.functional.normalize(k_b, dim=1) + + # undo shuffle + k = batch_unshuffle_ddp(k, idx_unshuffle) + k2 = batch_unshuffle_ddp(k2, idx_unshuffle) + k_grid = batch_unshuffle_ddp(k_grid, idx_unshuffle) + k_b = batch_unshuffle_ddp(k_b, idx_unshuffle) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + # feat point set sim + backbone_sim_matrix = torch.matmul(q_b.permute(0, 2, 1), k_b) + densecl_sim_ind = backbone_sim_matrix.max(dim=2)[1] # NxS^2 + + indexed_k_grid = torch.gather(k_grid, 2, + densecl_sim_ind.unsqueeze(1).expand( + -1, k_grid.size(1), -1)) # NxCxS^2 + densecl_sim_q = (q_grid * indexed_k_grid).sum(1) # NxS^2 + + # dense positive logits: NS^2X1 + l_pos_dense = densecl_sim_q.view(-1).unsqueeze(-1) + + q_grid = q_grid.permute(0, 2, 1) + q_grid = q_grid.reshape(-1, q_grid.size(2)) + # dense negative logits: NS^2xK + l_neg_dense = torch.einsum( + 'nc,ck->nk', [q_grid, self.queue2.clone().detach()]) + + loss_single = self.head.loss(l_pos, l_neg) + loss_dense = self.head.loss(l_pos_dense, l_neg_dense) + + losses = dict() + losses['loss_single'] = loss_single * (1 - self.loss_lambda) + losses['loss_dense'] = loss_dense * self.loss_lambda + + self._dequeue_and_enqueue(k) + self._dequeue_and_enqueue2(k2) + + return losses diff --git a/mmpretrain/models/selfsup/eva.py b/mmpretrain/models/selfsup/eva.py new file mode 100644 index 0000000..30779be --- /dev/null +++ b/mmpretrain/models/selfsup/eva.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class EVA(BaseSelfSupervisor): + """EVA. + + Implementation of `EVA: Exploring the Limits of Masked Visual + Representation Learning at Scale `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
+ """ + + clip_feature, _ = self.target_generator(inputs) + + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + + clip_feature = clip_feature[:, 1:, :] + loss = self.head.loss(pred, clip_feature, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py new file mode 100644 index 0000000..488a996 --- /dev/null +++ b/mmpretrain/models/selfsup/itpn.py @@ -0,0 +1,359 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models.backbones.hivit import BlockWithRPE, HiViT, PatchMerge +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class iTPNHiViT(HiViT): + """HiViT for iTPN pre-training. + + Args: + img_size (int | tuple): Input image size. Defaults to 224. + patch_size (int | tuple): The patch size. Defaults to 16. + inner_patches (int): Inner patch. Defaults to 4. + stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim + in the first two stages. Defaults to 3. + mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in + the last stage. Defaults to 4. + qkv_bias (bool): Enable bias for qkv projections if True. + qk_scale (float): The number of divider after q@k. Default to None. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): The drop out rate for attention output weights. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): If True, add absolute position embedding to + the patch embedding. + rpe (bool): If True, add relative position embedding to + the patch embedding. + layer_scale_init_value (float): Layer-scale init values. Defaults to 0. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + reconstruction_type (str): The reconstruction of self-supervised + learning. Defaults to 'pixel'. + """ + + def __init__( + self, + arch='base', + img_size: int = 224, + patch_size: int = 16, + inner_patches: int = 4, + stem_mlp_ratio: int = 3., + mlp_ratio: int = 4., + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + ape: bool = True, + rpe: bool = False, + layer_scale_init_value: float = 0.0, + mask_ratio: float = 0.75, + reconstruction_type: str = 'pixel', + **kwargs, + ): + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + stem_mlp_ratio=stem_mlp_ratio, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + ape=ape, + rpe=rpe, + layer_scale_init_value=layer_scale_init_value, + **kwargs, + ) + + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + + assert reconstruction_type in ['pixel', 'clip'], \ + 'iTPN method only support `pixel` and `clip`, ' \ + f'but got `{reconstruction_type}`.' 
+ self.reconstruction_type = reconstruction_type + self.num_patches = self.patch_embed.num_patches + + if reconstruction_type == 'clip': + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().apply(self._init_weights) + + if self.reconstruction_type == 'clip': + trunc_normal_(self.mask_token, std=0.02) + self.rescale_init_weight() + else: + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=False) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + if isinstance(layer, BlockWithRPE): + if layer.attn is not None: + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def masking_id(self, batch_size, mask_ratio): + N, L = batch_size, self.pos_embed.size(1) + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand( + N, L, device=self.pos_embed.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=self.pos_embed.device) + mask[:, :ids_keep.size(1)] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return ids_keep, ids_restore, mask + + def forward_pixel( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[Tuple, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. 
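+ Note: in the masked branch the ids of kept tokens are drawn by
+ ``masking_id`` before the transformer stages run, and the outputs of
+ the early stages are returned as well so the neck can use them as a
+ feature pyramid.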
+ """ + if mask is None or False: + return super().forward(x) + + else: + B, C, H, W = x.shape + ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio) + + x = self.patch_embed(x) + + x = torch.gather( + x, + dim=1, + index=ids_keep[:, :, None, None, + None].expand(-1, -1, *x.shape[2:])) + + outs = [] + for blk in self.blocks[:-self.num_main_blocks]: + if isinstance(blk, PatchMerge): + outs.append(x) + x = blk(x) + + x = x[..., 0, 0, :] + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + pos_embed = torch.gather( + pos_embed.expand(B, -1, -1), + dim=1, + index=ids_keep[:, :, None].expand(-1, -1, + pos_embed.shape[2]), + ) + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x) + + outs.append(x) + + return (tuple(outs), mask, ids_restore) + + def forward_clip(self, + x: torch.Tensor, + mask: Optional[bool] = True) -> Tuple: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B, C, H, W = x.shape + x = self.patch_embed(x) + + outs = [] + for blk in self.blocks[:-self.num_main_blocks]: + if isinstance(blk, PatchMerge): + outs.append(x) + x = blk(x) + + x = x[..., 0, 0, :] + B, L, _ = x.shape + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + x = x + pos_embed + x = self.pos_drop(x) + + rpe_index = True if self.rpe else None + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x, rpe_index) + + outs.append(x) + + return tuple(outs) + + def forward(self, x: torch.Tensor, mask: Optional[bool] = True) -> Tuple: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. 
+ + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + + if self.reconstruction_type == 'pixel': + return self.forward_pixel(x, mask) + return self.forward_clip(x, mask) + + +@MODELS.register_module() +class iTPN(BaseSelfSupervisor): + """iTPN. + + Implementation of `iTPN: Integrally Pre-Trained Transformer Pyramid + Networks `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + if self.backbone.reconstruction_type == 'pixel': + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + + loss = self.head.loss(pred, inputs, mask) + else: + mask = torch.stack( + [data_sample.mask for data_sample in data_samples]) + + img_latent = self.backbone(inputs[0], mask) + + # inputs[1] is the target image + with torch.no_grad(): + target = self.target_generator(inputs[1])[0] + target = target.detach() + + # iTPN contains a neck module + feats = self.neck(img_latent) + loss = self.head.loss(feats, target[:, 1:, :], mask) + + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mae.py b/mmpretrain/models/selfsup/mae.py new file mode 100644 index 0000000..01bc5bc --- /dev/null +++ b/mmpretrain/models/selfsup/mae.py @@ -0,0 +1,416 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch + +from mmpretrain.models import HiViT, VisionTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MAEViT(VisionTransformer): + """Vision Transformer for MAE pre-training. + + A PyTorch implement of: `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_. + This module implements the patch masking in MAE and initialize the + position embedding with sine-cosine position embedding. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. 
Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + # position embedding is not learnable during pretraining + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=True) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.projection.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + torch.nn.init.normal_(self.cls_token, std=.02) + + def random_masking( + self, + x: torch.Tensor, + mask_ratio: float = 0.75 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the mask for MAE Pre-training. + + Args: + x (torch.Tensor): Image with data augmentation applied, which is + of shape B x L x C. + mask_ratio (float): The mask ratio of total patches. + Defaults to 0.75. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: masked image, mask + and the ids to restore original image. + + - ``x_masked`` (torch.Tensor): masked image. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. 
+ """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, self.mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for _, layer in enumerate(self.layers): + x = layer(x) + # Use final norm + x = self.norm1(x) + + return (x, mask, ids_restore) + + +@MODELS.register_module() +class MAE(BaseSelfSupervisor): + """MAE. + + Implementation of `Masked Autoencoders Are Scalable Vision Learners + `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + latent, mask, ids_restore = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + loss = self.head.loss(pred, inputs, mask) + losses = dict(loss=loss) + return losses + + +@MODELS.register_module() +class MAEHiViT(HiViT): + """HiViT for MAE pre-training. 
+ + A PyTorch implement of: `HiViT: A Simple and More Efficient Design + of Hierarchical Vision Transformer `_. + This module implements the patch masking in MAE and initialize the + position embedding with sine-cosine position embedding. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + Defaults to 4, to downsample 4x at the first stage + inner_patches (int): The inner patches within a token + Defaults to 4 + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + ape (bool): the absolute position embedding + rpe (bool): the relative position embedding + Defaults to False + layer_scale_init_value (float): the layer scale init value + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + inner_patches: int = 4, + out_indices: Union[list, int] = [23], + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + ape: bool = True, + rpe: bool = False, + layer_scale_init_value: float = 0.0, + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + inner_patches=inner_patches, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + ape=ape, + rpe=rpe, + layer_scale_init_value=layer_scale_init_value, + init_cfg=init_cfg) + + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + self.num_patches = self.patch_embed.num_patches + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding.""" + super().apply(self._init_weights) + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=False) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def masking_id( + self, batch_size, + mask_ratio) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the mask for MAE Pre-training. + + Args: + batch_size: The batch size of input data + mask_ratio: The mask ratio of total patches. + Defaults to 0.75. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: the ids + for the tokens retained, the ids to restore original image, + and the mask + """ + N, L = batch_size, self.pos_embed.size(1) + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand( + N, L, device=self.pos_embed.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=self.pos_embed.device) + mask[:, :ids_keep.size(1)] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return ids_keep, ids_restore, mask + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B, C, H, W = x.shape + ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio) + + x = self.patch_embed(x) + + x = torch.gather( + x, + dim=1, + index=ids_keep[:, :, None, None, + None].expand(-1, -1, *x.shape[2:])) + + for blk in self.blocks[:-self.num_main_blocks]: + x = blk(x) + + x = x[..., 0, 0, :] + if self.ape: + pos_embed = self.interpolate_pos_encoding(x, H, W) + pos_embed = torch.gather( + pos_embed.expand(B, -1, -1), + dim=1, + index=ids_keep[:, :, None].expand(-1, -1, + pos_embed.shape[2]), + ) + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks[-self.num_main_blocks:]: + x = blk(x) + + return (x, mask, ids_restore) diff --git a/mmpretrain/models/selfsup/maskfeat.py b/mmpretrain/models/selfsup/maskfeat.py new file mode 100644 index 0000000..fd9f0b2 --- /dev/null +++ b/mmpretrain/models/selfsup/maskfeat.py @@ -0,0 +1,336 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Sequence, Union + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.models import VisionTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class HOGGenerator(BaseModule): + """Generate HOG feature for images. + + This module is used in MaskFeat to generate HOG feature. 
The code is + modified from file `slowfast/models/operators.py + `_. + Here is the link of `HOG wikipedia + `_. + + Args: + nbins (int): Number of bin. Defaults to 9. + pool (float): Number of cell. Defaults to 8. + gaussian_window (int): Size of gaussian kernel. Defaults to 16. + """ + + def __init__(self, + nbins: int = 9, + pool: int = 8, + gaussian_window: int = 16) -> None: + super().__init__() + self.nbins = nbins + self.pool = pool + self.pi = math.pi + weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) + weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous() + weight_y = weight_x.transpose(2, 3).contiguous() + self.register_buffer('weight_x', weight_x) + self.register_buffer('weight_y', weight_y) + + self.gaussian_window = gaussian_window + if gaussian_window: + gaussian_kernel = self.get_gaussian_kernel(gaussian_window, + gaussian_window // 2) + self.register_buffer('gaussian_kernel', gaussian_kernel) + + def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor: + """Returns a 2D Gaussian kernel array.""" + + def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor: + n = torch.arange(0, kernlen).float() + n -= n.mean() + n /= std + w = torch.exp(-0.5 * n**2) + return w + + kernel_1d = _gaussian_fn(kernlen, std) + kernel_2d = kernel_1d[:, None] * kernel_1d[None, :] + return kernel_2d / kernel_2d.sum() + + def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor: + """Reshape HOG Features for output.""" + hog_feat = hog_feat.flatten(1, 2) + self.unfold_size = hog_feat.shape[-1] // 14 + hog_feat = hog_feat.permute(0, 2, 3, 1) + hog_feat = hog_feat.unfold(1, self.unfold_size, + self.unfold_size).unfold( + 2, self.unfold_size, self.unfold_size) + hog_feat = hog_feat.flatten(1, 2).flatten(2) + return hog_feat + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Generate hog feature for each batch images. + + Args: + x (torch.Tensor): Input images of shape (N, 3, H, W). + + Returns: + torch.Tensor: Hog features. 
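+ Example (a minimal sketch; the output shape assumes the default
+ ``nbins=9``, ``pool=8`` and a 224x224 input, which maps to a 14x14
+ token grid):
+
+ >>> hog = HOGGenerator()
+ >>> feats = hog(torch.rand(2, 3, 224, 224))
+ >>> feats.shape
+ torch.Size([2, 196, 108])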
+ """ + # input is RGB image with shape [B 3 H W] + self.h, self.w = x.size(-2), x.size(-1) + x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect') + gx_rgb = F.conv2d( + x, self.weight_x, bias=None, stride=1, padding=0, groups=3) + gy_rgb = F.conv2d( + x, self.weight_y, bias=None, stride=1, padding=0, groups=3) + norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1) + phase = torch.atan2(gx_rgb, gy_rgb) + phase = phase / self.pi * self.nbins # [-9, 9] + + b, c, h, w = norm_rgb.shape + out = torch.zeros((b, c, self.nbins, h, w), + dtype=torch.float, + device=x.device) + phase = phase.view(b, c, 1, h, w) + norm_rgb = norm_rgb.view(b, c, 1, h, w) + if self.gaussian_window: + if h != self.gaussian_window: + assert h % self.gaussian_window == 0, 'h {} gw {}'.format( + h, self.gaussian_window) + repeat_rate = h // self.gaussian_window + temp_gaussian_kernel = self.gaussian_kernel.repeat( + [repeat_rate, repeat_rate]) + else: + temp_gaussian_kernel = self.gaussian_kernel + norm_rgb *= temp_gaussian_kernel + + out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb) + + out = out.unfold(3, self.pool, self.pool) + out = out.unfold(4, self.pool, self.pool) + out = out.sum(dim=[-1, -2]) + + self.out = F.normalize(out, p=2, dim=2) + + return self._reshape(self.out) + + def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray: + """Generate HOG image according to HOG features.""" + assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \ + 'Check the input batch size and the channcel number, only support'\ + '"batch_size = 1".' + hog_image = np.zeros([self.h, self.w]) + cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu()) + cell_width = self.pool / 2 + max_mag = np.array(cell_gradient).max() + angle_gap = 360 / self.nbins + + for x in range(cell_gradient.shape[1]): + for y in range(cell_gradient.shape[2]): + cell_grad = cell_gradient[:, x, y] + cell_grad /= max_mag + angle = 0 + for magnitude in cell_grad: + angle_radian = math.radians(angle) + x1 = int(x * self.pool + + magnitude * cell_width * math.cos(angle_radian)) + y1 = int(y * self.pool + + magnitude * cell_width * math.sin(angle_radian)) + x2 = int(x * self.pool - + magnitude * cell_width * math.cos(angle_radian)) + y2 = int(y * self.pool - + magnitude * cell_width * math.sin(angle_radian)) + magnitude = 0 if magnitude < 0 else magnitude + cv2.line(hog_image, (y1, x1), (y2, x2), + int(255 * math.sqrt(magnitude))) + angle += angle_gap + return hog_image + + +@MODELS.register_module() +class MaskFeatViT(VisionTransformer): + """Vision Transformer for MaskFeat pre-training. + + A PyTorch implement of: `Masked Feature Prediction for Self-Supervised + Visual Pre-Training `_. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). 
+ - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + with_cls_token=True, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.parameter.Parameter( + torch.zeros(1, 1, self.embed_dims), requires_grad=True) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, mask token and cls token.""" + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + nn.init.trunc_normal_(self.cls_token, std=.02) + nn.init.trunc_normal_(self.mask_token, std=.02) + nn.init.trunc_normal_(self.pos_embed, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m: torch.nn.Module) -> None: + if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> torch.Tensor: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images. + mask (torch.Tensor, optional): Input masks. + + Returns: + torch.Tensor: Features with cls_tokens. 
+ """ + if mask is None: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + + # masking: length -> length * mask_ratio + B, L, _ = x.shape + mask_tokens = self.mask_token.expand(B, L, -1) + mask = mask.unsqueeze(-1) + x = x * (1 - mask.int()) + mask_tokens * mask + + # append cls token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.drop_after_pos(x) + + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + return x + + +@MODELS.register_module() +class MaskFeat(BaseSelfSupervisor): + """MaskFeat. + + Implementation of `Masked Feature Prediction for Self-Supervised Visual + Pre-Training `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + mask = mask.flatten(1).bool() + + latent = self.backbone(inputs, mask) + B, L, C = latent.shape + pred = self.neck((latent.view(B * L, C), )) + pred = pred[0].view(B, L, -1) + hog = self.target_generator(inputs) + + # remove cls_token before compute loss + loss = self.head.loss(pred[:, 1:], hog, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mff.py b/mmpretrain/models/selfsup/mff.py new file mode 100644 index 0000000..2685058 --- /dev/null +++ b/mmpretrain/models/selfsup/mff.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F + +from mmpretrain.models.selfsup.mae import MAE, MAEViT +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class MFFViT(MAEViT): + """Vision Transformer for MFF Pretraining. + + This class inherits all these functionalities from ``MAEViT``, and + add multi-level feature fusion to it. For more details, you can + refer to `Improving Pixel-based MIM by Reducing Wasted Modeling + Capability`. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. 
Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + mask_ratio=mask_ratio, + init_cfg=init_cfg) + proj_layers = [ + torch.nn.Linear(self.embed_dims, self.embed_dims) + for _ in range(len(self.out_indices) - 1) + ] + self.proj_layers = torch.nn.ModuleList(proj_layers) + self.proj_weights = torch.nn.Parameter( + torch.ones(len(self.out_indices)).view(-1, 1, 1, 1)) + if len(self.out_indices) == 1: + self.proj_weights.requires_grad = False + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. 
+ """ + if mask is None or False: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, self.mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + res = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + if i != self.out_indices[-1]: + proj_x = self.proj_layers[self.out_indices.index(i)](x) + else: + proj_x = x + res.append(proj_x) + res = torch.stack(res) + proj_weights = F.softmax(self.proj_weights, dim=0) + res = res * proj_weights + res = res.sum(dim=0) + + # Use final norm + x = self.norm1(res) + return (x, mask, ids_restore, proj_weights.view(-1)) + + +@MODELS.register_module() +class MFF(MAE): + """MFF. + + Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling + Capability`. + """ + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + latent, mask, ids_restore, weights = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + loss = self.head.loss(pred, inputs, mask) + weight_params = { + f'weight_{i}': weights[i] + for i in range(weights.size(0)) + } + losses = dict(loss=loss) + losses.update(weight_params) + return losses diff --git a/mmpretrain/models/selfsup/milan.py b/mmpretrain/models/selfsup/milan.py new file mode 100644 index 0000000..fdf8673 --- /dev/null +++ b/mmpretrain/models/selfsup/milan.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmengine.runner.checkpoint import _load_checkpoint + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_clip_model +from .base import BaseSelfSupervisor +from .mae import MAEViT + + +@MODELS.register_module() +class CLIPGenerator(nn.Module): + """Get the features and attention from the last layer of CLIP. + + This module is used to generate target features in masked image modeling. + + Args: + tokenizer_path (str): The path of the checkpoint of CLIP. + """ + + def __init__(self, tokenizer_path: str) -> None: + super().__init__() + self.tokenizer_path = tokenizer_path + self.tokenizer = build_clip_model( + _load_checkpoint(self.tokenizer_path), False) + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the features and attention from the last layer of CLIP. + + Args: + x (torch.Tensor): The input image, which is of shape (N, 3, H, W). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The features and attention from + the last layer of CLIP, which are of shape (N, L, C) and (N, L, L), + respectively. + """ + # use the visual branch of CLIP to get the features + assert self.tokenizer is not None, 'Please check whether the ' \ + '`self.tokenizer` is initialized correctly.' 
+ + clip_features = self.tokenizer.encode_image(x) + return clip_features + + +@MODELS.register_module() +class MILANViT(MAEViT): + """Vision Transformer for MILAN pre-training. + + Implementation of the encoder for `MILAN: Masked Image Pretraining on + Language Assisted Representation `_. + + This module inherits from MAEViT and only overrides the forward function + and replace random masking with attention masking. + """ + + def attention_masking( + self, x: torch.Tensor, mask_ratio: float, importance: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate attention mask for MILAN. + + This is what is different from MAEViT, which uses random masking. + Attention masking generates attention mask for MILAN, according to + importance. The higher the importance, the more likely the patch is + kept. + + Args: + x (torch.Tensor): Input images, which is of shape B x L x C. + mask_ratio (float): The ratio of patches to be masked. + importance (torch.Tensor): Importance of each patch, which is of + shape B x L. + + Returns: + Tuple[torch.Tensor, ...]: + + - ``x_masked``: masked image + - ``ids_restore``: the ids to restore original image + - ``ids_keep``: ids of the kept patches + - ``ids_dump``: ids of the removed patches + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = importance.to(x.device) # large is keep, small is remove + + # sort noise for each sample + ids_shuffle = torch.multinomial(noise, L, replacement=False) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + ids_dump = ids_shuffle[:, len_keep:] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, ids_restore, ids_keep, ids_dump + + def forward( + self, + x: torch.Tensor, + importance: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the + ``importance`` is ``None``, the function generates mask and masks some + patches randomly and get the hidden features for visible patches. The + mask is generated by importance. The higher the importance, the more + likely the patch is kept. The importance is calculated by CLIP. + The higher the CLIP score, the more likely the patch is kept. The CLIP + score is calculated by cross attention between the class token and all + other tokens from the last layer. + If the ``importance`` is ``torch.Tensor``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + importance (torch.Tensor, optional): Importance of each patch, + which is of shape B x L. + + Returns: + Tuple[torch.Tensor, ...]: masked image, the ids to restore original + image, ids of the kept patches, ids of the removed patches. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + - ``ids_keep`` (torch.Tensor): ids of the kept patches. + - ``ids_dump`` (torch.Tensor): ids of the removed patches. 
+ """ + if importance is None: + return super(MAEViT, self).forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, ids_restore, ids_keep, ids_dump = self.attention_masking( + x, self.mask_ratio, importance) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for _, layer in enumerate(self.layers): + x = layer(x) + # Use final norm + x = self.norm1(x) + + return x, ids_restore, ids_keep, ids_dump + + +@MODELS.register_module() +class MILAN(BaseSelfSupervisor): + """MILAN. + + Implementation of `MILAN: Masked Image Pretraining on Language Assisted + Representation `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, importance=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + clip_feature, importance = self.target_generator(inputs) + importance = importance[:, 0, 1:] + latent, ids_restore, ids_keep, ids_dump = self.backbone( + inputs, importance) + pred = self.neck(latent, ids_restore, ids_keep, ids_dump) + + loss = self.head.loss(pred, clip_feature) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/mixmim.py b/mmpretrain/models/selfsup/mixmim.py new file mode 100644 index 0000000..b202f83 --- /dev/null +++ b/mmpretrain/models/selfsup/mixmim.py @@ -0,0 +1,263 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from mmpretrain.models.backbones import MixMIMTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import build_2d_sincos_position_embedding +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MixMIMPretrainTransformer(MixMIMTransformer): + """MixMIM backbone for MixMIM pre-training. + + A PyTorch implement of : ` MixMIM: Mixed and Masked Image + Modeling for Efficient Visual Representation Learning + `_ + + Args: + arch (str | dict): MixMIM architecture. If use string, + choose from 'base','large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + + Defaults to 'base'. + mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to mlp_ratio + the most common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + window_size (list): The height and width of the window. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. 
+ patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + attn_drop_rate (float): Attention drop rate. Defaults to 0. + use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory + cost. Defaults to False. + mask_ratio (bool): The base ratio of total number of patches to be + masked. Defaults to 0.5. + range_mask_ratio (float): The range of mask ratio. + Defaults to 0. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'base', + mlp_ratio: float = 4, + img_size: int = 224, + patch_size: int = 4, + in_channels: int = 3, + window_size: List = [14, 14, 14, 7], + qkv_bias: bool = True, + patch_cfg: dict = dict(), + norm_cfg: dict = dict(type='LN'), + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + attn_drop_rate: float = 0.0, + use_checkpoint: bool = False, + mask_ratio: float = 0.5, + range_mask_ratio: float = 0.0, + init_cfg: Optional[dict] = None) -> None: + + super().__init__( + arch=arch, + mlp_ratio=mlp_ratio, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + window_size=window_size, + qkv_bias=qkv_bias, + patch_cfg=patch_cfg, + norm_cfg=norm_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + attn_drop_rate=attn_drop_rate, + use_checkpoint=use_checkpoint, + init_cfg=init_cfg) + + self.mask_ratio = mask_ratio + self.range_mask_ratio = range_mask_ratio + + def init_weights(self): + """Initialize position embedding, patch embedding.""" + super(MixMIMTransformer, self).init_weights() + + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.absolute_pos_embed.shape[-1], + cls_token=False) + self.absolute_pos_embed.data.copy_(pos_embed.float()) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def random_masking(self, + x: torch.Tensor, + mask_ratio: float = 0.5) -> Tuple[torch.Tensor]: + """Generate the mask for MixMIM Pretraining. + + Args: + x (torch.Tensor): Image with data augmentation applied, which is + of shape B x L x C. + mask_ratio (float): The mask ratio of total patches. + Defaults to 0.5. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + - mask_s1 (torch.Tensor): mask with stride of + self.encoder_stride // 8. + - mask_s2 (torch.Tensor): mask with stride of + self.encoder_stride // 4. + - mask_s3 (torch.Tensor): mask with stride of + self.encoder_stride // 2. + - mask (torch.Tensor): mask with stride of + self.encoder_stride. 
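+ Note: a single mask is drawn and shared by the whole batch; in
+ ``forward`` the masked positions of each image are filled with the
+ corresponding patches of ``x.flip(0)``, which is the "mixing" step
+ that gives MixMIM its name.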
+ """ + + B, C, H, W = x.shape + out_H = H // self.encoder_stride + out_W = W // self.encoder_stride + s3_H, s3_W = out_H * 2, out_W * 2 + s2_H, s2_W = out_H * 4, out_W * 4 + s1_H, s1_W = out_H * 8, out_W * 8 + + seq_l = out_H * out_W + # use a shared mask for a batch images + mask = torch.zeros([1, 1, seq_l], device=x.device) + + mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio) + noise = torch.rand(1, 1, seq_l, device=x.device) # noise in [0, 1] + # ascend: small is keep, large is removed + mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)] + mask.scatter_(2, mask_idx, 1) + mask = mask.reshape(1, 1, out_H, out_W) + mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest') + mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest') + mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest') + + mask = mask.reshape(1, out_H * out_W, 1).contiguous() + mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous() + mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous() + mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous() + + return mask_s1, mask_s2, mask_s3, mask + + def forward(self, + x: torch.Tensor, + mask: Optional[bool] = True) -> Tuple[torch.Tensor]: + """Generate features for masked images. + + This function generates mask and masks some patches randomly and get + the hidden features for visible patches. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward containing + ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - x (torch.Tensor): hidden features, which is of shape + B x L x C. + - mask_s4 (torch.Tensor): the mask tensor for the last layer. + """ + if mask is None or False: + return super().forward(x) + + else: + mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking( + x, self.mask_ratio) + + x, _ = self.patch_embed(x) + + x = x * (1. - mask_s1) + x.flip(0) * mask_s1 + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for idx, layer in enumerate(self.layers): + if idx == 0: + x = layer(x, attn_mask=mask_s1) + elif idx == 1: + x = layer(x, attn_mask=mask_s2) + elif idx == 2: + x = layer(x, attn_mask=mask_s3) + elif idx == 3: + x = layer(x, attn_mask=mask_s4) + + x = self.norm(x) + + return x, mask_s4 + + +@MODELS.register_module() +class MixMIM(BaseSelfSupervisor): + """MixMIM. + + Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient + Visual Representation Learning. `_. + """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + pretrained: Optional[str] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + + head.update(dict(patch_size=neck['encoder_stride'])) + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
+ """ + latent, mask = self.backbone(inputs) + x_rec = self.neck(latent, mask) + loss = self.head.loss(x_rec, inputs, mask) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/moco.py b/mmpretrain/models/selfsup/moco.py new file mode 100644 index 0000000..7ff4cf8 --- /dev/null +++ b/mmpretrain/models/selfsup/moco.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.dist import all_gather +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class MoCo(BaseSelfSupervisor): + """MoCo. + + Implementation of `Momentum Contrast for Unsupervised Visual + Representation Learning `_. + Part of the code is borrowed from: + ``_. + + Args: + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features to compact feature + vectors. + head (dict): Config dict for module of head functions. + queue_len (int): Number of negative keys maintained in the + queue. Defaults to 65536. + feat_dim (int): Dimension of compact feature vectors. + Defaults to 128. + momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.001. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + queue_len: int = 65536, + feat_dim: int = 128, + momentum: float = 0.001, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.encoder_k = ExponentialMovingAverage( + nn.Sequential(self.backbone, self.neck), momentum) + + # create the queue + self.queue_len = queue_len + self.register_buffer('queue', torch.randn(feat_dim, queue_len)) + self.queue = nn.functional.normalize(self.queue, dim=0) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None: + """Update queue.""" + # gather keys before updating queue + keys = torch.cat(all_gather(keys), dim=0) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_len % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) + ptr = (ptr + batch_size) % self.queue_len # move pointer + + self.queue_ptr[0] = ptr + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. 
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+        assert isinstance(inputs, list)
+        im_q = inputs[0]
+        im_k = inputs[1]
+        # compute query features from encoder_q
+        q = self.neck(self.backbone(im_q))[0]  # queries: NxC
+        q = nn.functional.normalize(q, dim=1)
+
+        # compute key features
+        with torch.no_grad():  # no gradient to keys
+            # update the key encoder
+            self.encoder_k.update_parameters(
+                nn.Sequential(self.backbone, self.neck))
+
+            # shuffle for making use of BN
+            im_k, idx_unshuffle = batch_shuffle_ddp(im_k)
+
+            k = self.encoder_k(im_k)[0]  # keys: NxC
+            k = nn.functional.normalize(k, dim=1)
+
+            # undo shuffle
+            k = batch_unshuffle_ddp(k, idx_unshuffle)
+
+        # compute logits
+        # Einstein sum is more intuitive
+        # positive logits: Nx1
+        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
+        # negative logits: NxK
+        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+
+        loss = self.head.loss(l_pos, l_neg)
+        # update the queue
+        self._dequeue_and_enqueue(k)
+
+        losses = dict(loss=loss)
+        return losses
diff --git a/mmpretrain/models/selfsup/mocov3.py b/mmpretrain/models/selfsup/mocov3.py
new file mode 100644
index 0000000..61b8033
--- /dev/null
+++ b/mmpretrain/models/selfsup/mocov3.py
@@ -0,0 +1,215 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from functools import reduce
+from operator import mul
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpretrain.models.backbones import VisionTransformer
+from mmpretrain.models.utils import (build_2d_sincos_position_embedding,
+                                     to_2tuple)
+from mmpretrain.registry import MODELS
+from mmpretrain.structures import DataSample
+from ..utils import CosineEMA
+from .base import BaseSelfSupervisor
+
+
+@MODELS.register_module()
+class MoCoV3ViT(VisionTransformer):
+    """Vision Transformer for MoCoV3 pre-training.
+
+    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
+    for Image Recognition at Scale `_.
+
+    Part of the code is modified from:
+    ``_.
+
+    Args:
+        stop_grad_conv1 (bool): Whether to stop the gradient of the
+            convolution layer in `PatchEmbed`. Defaults to False.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Defaults to None.
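+
+    Examples:
+        A minimal construction sketch (illustrative values only, not taken
+        from the configs in this patch; ``patch_size`` must be passed
+        explicitly because it is read from ``kwargs``)::
+
+            >>> backbone = MoCoV3ViT(
+            ...     arch='mocov3-small',    # arch added by this class
+            ...     img_size=224,
+            ...     patch_size=16,
+            ...     stop_grad_conv1=True)   # freeze the patch-embed projection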
+ """ + + def __init__(self, + stop_grad_conv1: bool = False, + frozen_stages: int = -1, + norm_eval: bool = False, + init_cfg: Optional[Union[dict, List[dict]]] = None, + **kwargs) -> None: + + # add MoCoV3 ViT-small arch + self.arch_zoo.update( + dict.fromkeys( + ['mocov3-s', 'mocov3-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 1536, + })) + + super().__init__(init_cfg=init_cfg, **kwargs) + self.patch_size = kwargs['patch_size'] + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.init_cfg = init_cfg + + if stop_grad_conv1: + self.patch_embed.projection.weight.requires_grad = False + self.patch_embed.projection.bias.requires_grad = False + + self._freeze_stages() + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding, qkv layers and cls + token.""" + super().init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + # Use fixed 2D sin-cos position embedding + pos_emb = build_2d_sincos_position_embedding( + patches_resolution=self.patch_resolution, + embed_dims=self.embed_dims, + cls_token=True) + self.pos_embed.data.copy_(pos_emb) + self.pos_embed.requires_grad = False + + # xavier_uniform initialization for PatchEmbed + val = math.sqrt( + 6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) + + self.embed_dims)) + nn.init.uniform_(self.patch_embed.projection.weight, -val, val) + nn.init.zeros_(self.patch_embed.projection.bias) + + # initialization for linear layers + for name, m in self.named_modules(): + if isinstance(m, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt( + 6. / + float(m.weight.shape[0] // 3 + m.weight.shape[1])) + nn.init.uniform_(m.weight, -val, val) + else: + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + nn.init.normal_(self.cls_token, std=1e-6) + + def _freeze_stages(self) -> None: + """Freeze patch_embed layer, some parameters and stages.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + self.cls_token.requires_grad = False + self.pos_embed.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if i == (self.num_layers) and self.final_norm: + for param in getattr(self, 'norm1').parameters(): + param.requires_grad = False + + def train(self, mode: bool = True) -> None: + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() +class MoCoV3(BaseSelfSupervisor): + """MoCo v3. + + Implementation of `An Empirical Study of Training Self-Supervised Vision + Transformers `_. + + Args: + backbone (dict): Config dict for module of backbone + neck (dict): Config dict for module of deep features to compact feature + vectors. + head (dict): Config dict for module of head functions. + base_momentum (float): Momentum coefficient for the momentum-updated + encoder. Defaults to 0.01. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing + input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. 
+ See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (Union[List[dict], dict], optional): Config dict for weight + initialization. Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + base_momentum: float = 0.01, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.momentum_encoder = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + view_1 = inputs[0] + view_2 = inputs[1] + + # compute query features, [N, C] each + q1 = self.neck(self.backbone(view_1))[0] + q2 = self.neck(self.backbone(view_2))[0] + + # compute key features, [N, C] each, no gradient + with torch.no_grad(): + # update momentum encoder + self.momentum_encoder.update_parameters( + nn.Sequential(self.backbone, self.neck)) + + k1 = self.momentum_encoder(view_1)[0] + k2 = self.momentum_encoder(view_2)[0] + + loss = self.head.loss(q1, k2) + self.head.loss(q2, k1) + + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/simclr.py b/mmpretrain/models/selfsup/simclr.py new file mode 100644 index 0000000..4b19ab4 --- /dev/null +++ b/mmpretrain/models/selfsup/simclr.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List, Tuple + +import torch +from mmengine.dist import all_gather, get_rank + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +class GatherLayer(torch.autograd.Function): + """Gather tensors from all process, supporting backward propagation.""" + + @staticmethod + def forward(ctx: Any, input: torch.Tensor) -> Tuple[List]: + ctx.save_for_backward(input) + output = all_gather(input) + return tuple(output) + + @staticmethod + def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor: + input, = ctx.saved_tensors + grad_out = torch.zeros_like(input) + grad_out[:] = grads[get_rank()] + return grad_out + + +@MODELS.register_module() +class SimCLR(BaseSelfSupervisor): + """SimCLR. + + Implementation of `A Simple Framework for Contrastive Learning of Visual + Representations `_. + """ + + @staticmethod + def _create_buffer( + batch_size: int, device: torch.device + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the mask and the index of positive samples. + + Args: + batch_size (int): The batch size. + device (torch.device): The device of backend. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - The mask for feature selection. + - The index of positive samples. + - The mask of negative samples. 
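+
+        Examples:
+            A shape sketch (illustrative only): with ``batch_size=2`` the
+            2N x 2N similarity matrix is 4x4, and after the diagonal is
+            removed each row keeps one positive and ``2N - 2 = 2`` negatives::
+
+                >>> mask, pos_idx, neg_mask = SimCLR._create_buffer(2, 'cpu')
+                >>> mask.shape, neg_mask.shape
+                (torch.Size([4, 4]), torch.Size([4, 3]))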
+ """ + mask = 1 - torch.eye(batch_size * 2, dtype=torch.uint8).to(device) + pos_idx = ( + torch.arange(batch_size * 2).to(device), + 2 * torch.arange(batch_size, dtype=torch.long).unsqueeze(1).repeat( + 1, 2).view(-1, 1).squeeze().to(device)) + neg_mask = torch.ones((batch_size * 2, batch_size * 2 - 1), + dtype=torch.uint8).to(device) + neg_mask[pos_idx] = 0 + return mask, pos_idx, neg_mask + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + inputs = torch.stack(inputs, 1) + inputs = inputs.reshape((inputs.size(0) * 2, inputs.size(2), + inputs.size(3), inputs.size(4))) + x = self.backbone(inputs) + z = self.neck(x)[0] # (2n)xd + + z = z / (torch.norm(z, p=2, dim=1, keepdim=True) + 1e-10) + z = torch.cat(GatherLayer.apply(z), dim=0) # (2N)xd + assert z.size(0) % 2 == 0 + N = z.size(0) // 2 + s = torch.matmul(z, z.permute(1, 0)) # (2N)x(2N) + mask, pos_idx, neg_mask = self._create_buffer(N, s.device) + + # remove diagonal, (2N)x(2N-1) + s = torch.masked_select(s, mask == 1).reshape(s.size(0), -1) + positive = s[pos_idx].unsqueeze(1) # (2N)x1 + + # select negative, (2N)x(2N-2) + negative = torch.masked_select(s, neg_mask == 1).reshape(s.size(0), -1) + + loss = self.head.loss(positive, negative) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/simmim.py b/mmpretrain/models/selfsup/simmim.py new file mode 100644 index 0000000..635a329 --- /dev/null +++ b/mmpretrain/models/selfsup/simmim.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models import SwinTransformer +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SimMIMSwinTransformer(SwinTransformer): + """Swin Transformer for SimMIM pre-training. + + Args: + Args: + arch (str | dict): Swin Transformer architecture + Defaults to 'T'. + img_size (int | tuple): The size of input image. + Defaults to 224. + in_channels (int): The num of input channels. + Defaults to 3. + drop_rate (float): Dropout rate after embedding. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. + Defaults to 0.1. + out_indices (tuple): Layers to be outputted. Defaults to (3, ). + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + norm_cfg (dict): Config dict for normalization layer at end + of backbone. Defaults to dict(type='LN') + stage_cfgs (Sequence | dict): Extra config dict for each + stage. Defaults to empty dict. 
+ patch_cfg (dict): Extra config dict for patch embedding. + Defaults to empty dict. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'T', + img_size: Union[Tuple[int, int], int] = 224, + in_channels: int = 3, + drop_rate: float = 0., + drop_path_rate: float = 0.1, + out_indices: tuple = (3, ), + use_abs_pos_embed: bool = False, + with_cp: bool = False, + frozen_stages: bool = -1, + norm_eval: bool = False, + norm_cfg: dict = dict(type='LN'), + stage_cfgs: Union[Sequence, dict] = dict(), + patch_cfg: dict = dict(), + pad_small_map: bool = False, + init_cfg: Optional[dict] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + in_channels=in_channels, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + out_indices=out_indices, + use_abs_pos_embed=use_abs_pos_embed, + with_cp=with_cp, + frozen_stages=frozen_stages, + norm_eval=norm_eval, + norm_cfg=norm_cfg, + stage_cfgs=stage_cfgs, + patch_cfg=patch_cfg, + pad_small_map=pad_small_map, + init_cfg=init_cfg) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize weights.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + trunc_normal_(self.mask_token, mean=0, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + """Initialize weights.""" + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x: torch.Tensor, + mask: Optional[torch.Tensor]) -> Sequence[torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + not ``None``, the forward function will be executed as masked image + modeling pre-training; if the ``mask`` is ``None``, the forward + function will call ``super().forward()``, which extract features from + images without mask. + + Args: + x (torch.Tensor): Input images. + mask (torch.Tensor, optional): Masks for images. + + Returns: + tuple: A tuple containing features from multi-stages. + """ + if mask is None: + return super().forward(x) + + else: + x, hw_shape = self.patch_embed(x) + B, L, _ = x.shape + + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + stage.out_channels).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) + + +@MODELS.register_module() +class SimMIM(BaseSelfSupervisor): + """SimMIM. 
+ + Implementation of `SimMIM: A Simple Framework for Masked Image Modeling + `_. + """ + + def extract_feat(self, inputs: torch.Tensor): + return self.backbone(inputs, mask=None) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack([data_sample.mask for data_sample in data_samples]) + + img_latent = self.backbone(inputs, mask) + img_rec = self.neck(img_latent[0]) + loss = self.head.loss(img_rec, inputs, mask) + losses = dict(loss=loss) + + return losses diff --git a/mmpretrain/models/selfsup/simsiam.py b/mmpretrain/models/selfsup/simsiam.py new file mode 100644 index 0000000..a502cd7 --- /dev/null +++ b/mmpretrain/models/selfsup/simsiam.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SimSiam(BaseSelfSupervisor): + """SimSiam. + + Implementation of `Exploring Simple Siamese Representation Learning + `_. The operation of fixing learning rate + of predictor is in `engine/hooks/simsiam_hook.py`. + """ + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + img_v1 = inputs[0] + img_v2 = inputs[1] + + z1 = self.neck(self.backbone(img_v1))[0] # NxC + z2 = self.neck(self.backbone(img_v2))[0] # NxC + + loss_1 = self.head.loss(z1, z2) + loss_2 = self.head.loss(z2, z1) + + losses = dict(loss=0.5 * (loss_1 + loss_2)) + return losses diff --git a/mmpretrain/models/selfsup/spark.py b/mmpretrain/models/selfsup/spark.py new file mode 100644 index 0000000..d5570a5 --- /dev/null +++ b/mmpretrain/models/selfsup/spark.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils.norm import build_norm_layer +from ..utils.sparse_modules import SparseHelper +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SparK(BaseSelfSupervisor): + """Implementation of SparK. + + Implementation of `Designing BERT for Convolutional Networks: Sparse and + Hierarchical Masked Modeling `_. 
+ + Modified from + https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py + """ + + def __init__( + self, + backbone: dict, + neck: dict, + head: dict, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + input_size: int = 224, + downsample_raito: int = 32, + mask_ratio: float = 0.6, + enc_dec_norm_cfg=dict(type='SparseSyncBatchNorm2d'), + enc_dec_norm_dim: int = 2048, + init_cfg: Optional[dict] = None, + ) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + self.input_size = input_size + self.downsample_raito = downsample_raito + feature_map_size = input_size // downsample_raito + self.feature_map_size = feature_map_size + + self.mask_ratio = mask_ratio + self.len_keep = round(feature_map_size * feature_map_size * + (1 - mask_ratio)) + + self.enc_dec_norm_cfg = enc_dec_norm_cfg + self.enc_dec_norms = nn.ModuleList() + self.enc_dec_projectors = nn.ModuleList() + self.mask_tokens = nn.ParameterList() + + proj_out_dim = self.neck.feature_dim + for i in range(len(self.backbone.out_indices)): + enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg, + enc_dec_norm_dim) + self.enc_dec_norms.append(enc_dec_norm) + + kernel_size = 1 if i <= 0 else 3 + proj_layer = nn.Conv2d( + enc_dec_norm_dim, + proj_out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=True) + if i == 0 and enc_dec_norm_dim == proj_out_dim: + proj_layer = nn.Identity() + self.enc_dec_projectors.append(proj_layer) + + mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1)) + trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02) + self.mask_tokens.append(mask_token) + + enc_dec_norm_dim //= 2 + proj_out_dim //= 2 + feature_map_size *= 2 + + def mask(self, + shape: torch.Size, + device: Union[torch.device, str], + generator: Optional[torch.Generator] = None): + """Mask generation. + + Args: + shape (torch.Size): The shape of the input images. + device (Union[torch.device, str]): The device of the tensor. + generator (torch.Generator, optional): Generator for random + functions. Defaults to None + Returns: + torch.Tensor: The generated mask. + """ + B, C, H, W = shape + f = self.feature_map_size + idx = torch.rand(B, f * f, generator=generator).argsort(dim=1) + idx = idx[:, :self.len_keep].to(device) # (B, len_keep) + return torch.zeros( + B, f * f, dtype=torch.bool, device=device).scatter_( + dim=1, index=idx, value=True).view(B, 1, f, f) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
+ """ + + # active mask of feature map, (B, 1, f, f) + active_mask_feature_map = self.mask(inputs.shape, inputs.device) + SparseHelper._cur_active = active_mask_feature_map + + # active mask of original input, (B, 1, H, W) + active_mask_origin = active_mask_feature_map.repeat_interleave( + self.downsample_raito, + 2).repeat_interleave(self.downsample_raito, 3) + masked_img = inputs * active_mask_origin + + # get hierarchical encoded sparse features in a list + # containing four feature maps + feature_maps = self.backbone(masked_img) + + # from the smallest feature map to the largest + feature_maps = list(feature_maps) + feature_maps.reverse() + + cur_active = active_mask_feature_map + feature_maps_to_dec = [] + for i, feature_map in enumerate(feature_maps): + if feature_map is not None: + # fill in empty positions with [mask] embeddings + feature_map = self.enc_dec_norms[i](feature_map) + mask_token = self.mask_tokens[i].expand_as(feature_map) + feature_map = torch.where( + cur_active.expand_as(feature_map), feature_map, + mask_token.to(feature_map.dtype)) + feature_map = self.enc_dec_projectors[i](feature_map) + feature_maps_to_dec.append(feature_map) + + # dilate the mask map + cur_active = cur_active.repeat_interleave( + 2, dim=2).repeat_interleave( + 2, dim=3) + + # decode and reconstruct + rec_img = self.neck(feature_maps_to_dec) + + # compute loss + loss = self.head(rec_img, inputs, active_mask_feature_map) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/selfsup/swav.py b/mmpretrain/models/selfsup/swav.py new file mode 100644 index 0000000..efe0eab --- /dev/null +++ b/mmpretrain/models/selfsup/swav.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SwAV(BaseSelfSupervisor): + """SwAV. + + Implementation of `Unsupervised Learning of Visual Features by Contrasting + Cluster Assignments `_. + + The queue is built in ``mmpretrain/engine/hooks/swav_hook.py``. + """ + + def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """Forward computation during training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + assert isinstance(inputs, list) + # multi-res forward passes + idx_crops = torch.cumsum( + torch.unique_consecutive( + torch.tensor([input.shape[-1] for input in inputs]), + return_counts=True)[1], 0) + start_idx = 0 + output = [] + for end_idx in idx_crops: + _out = self.backbone(torch.cat(inputs[start_idx:end_idx])) + output.append(_out) + start_idx = end_idx + output = self.neck(output) + + loss = self.head.loss(output) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/tta/__init__.py b/mmpretrain/models/tta/__init__.py new file mode 100644 index 0000000..568e64f --- /dev/null +++ b/mmpretrain/models/tta/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .score_tta import AverageClsScoreTTA + +__all__ = ['AverageClsScoreTTA'] diff --git a/mmpretrain/models/tta/score_tta.py b/mmpretrain/models/tta/score_tta.py new file mode 100644 index 0000000..5b8a078 --- /dev/null +++ b/mmpretrain/models/tta/score_tta.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from mmengine.model import BaseTTAModel + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class AverageClsScoreTTA(BaseTTAModel): + + def merge_preds( + self, + data_samples_list: List[List[DataSample]], + ) -> List[DataSample]: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[List[DataSample]]): List of predictions + of all enhanced data. + + Returns: + List[DataSample]: Merged prediction. + """ + merged_data_samples = [] + for data_samples in data_samples_list: + merged_data_samples.append(self._merge_single_sample(data_samples)) + return merged_data_samples + + def _merge_single_sample(self, data_samples): + merged_data_sample: DataSample = data_samples[0].new() + merged_score = sum(data_sample.pred_score + for data_sample in data_samples) / len(data_samples) + merged_data_sample.set_pred_score(merged_score) + return merged_data_sample diff --git a/mmpretrain/models/utils/__init__.py b/mmpretrain/models/utils/__init__.py new file mode 100644 index 0000000..e59d71d --- /dev/null +++ b/mmpretrain/models/utils/__init__.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpretrain.utils.dependency import WITH_MULTIMODAL +from .attention import (BEiTAttention, ChannelMultiheadAttention, + CrossMultiheadAttention, LeAttention, + MultiheadAttention, PromptMultiheadAttention, + ShiftWindowMSA, WindowMSA, WindowMSAV2) +from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix +from .batch_shuffle import batch_shuffle_ddp, batch_unshuffle_ddp +from .channel_shuffle import channel_shuffle +from .clip_generator_helper import QuickGELU, build_clip_model +from .data_preprocessor import (ClsDataPreprocessor, + MultiModalDataPreprocessor, + SelfSupDataPreprocessor, + TwoNormDataPreprocessor, VideoDataPreprocessor) +from .ema import CosineEMA +from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed, + resize_relative_position_bias_table) +from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple +from .inverted_residual import InvertedResidual +from .layer_scale import LayerScale +from .make_divisible import make_divisible +from .norm import GRN, LayerNorm2d, build_norm_layer +from .position_encoding import (ConditionalPositionEncoding, + PositionEncodingFourier, RotaryEmbeddingFast, + build_2d_sincos_position_embedding) +from .res_layer_extra_norm import ResLayerExtraNorm +from .se_layer import SELayer +from .sparse_modules import (SparseAvgPooling, SparseBatchNorm2d, SparseConv2d, + SparseHelper, SparseLayerNorm2D, SparseMaxPooling, + SparseSyncBatchNorm2d) +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .vector_quantizer import NormEMAVectorQuantizer + +__all__ = [ + 'channel_shuffle', + 'make_divisible', + 'InvertedResidual', + 'SELayer', + 'to_ntuple', + 'to_2tuple', + 'to_3tuple', + 'to_4tuple', + 'PatchEmbed', + 'PatchMerging', + 'HybridEmbed', + 'RandomBatchAugment', + 'ShiftWindowMSA', + 'is_tracing', + 'MultiheadAttention', + 'ConditionalPositionEncoding', + 'resize_pos_embed', + 'resize_relative_position_bias_table', + 'ClsDataPreprocessor', + 
'Mixup', + 'CutMix', + 'ResizeMix', + 'BEiTAttention', + 'LayerScale', + 'WindowMSA', + 'WindowMSAV2', + 'ChannelMultiheadAttention', + 'PositionEncodingFourier', + 'LeAttention', + 'GRN', + 'LayerNorm2d', + 'build_norm_layer', + 'CrossMultiheadAttention', + 'build_2d_sincos_position_embedding', + 'PromptMultiheadAttention', + 'NormEMAVectorQuantizer', + 'build_clip_model', + 'batch_shuffle_ddp', + 'batch_unshuffle_ddp', + 'SelfSupDataPreprocessor', + 'TwoNormDataPreprocessor', + 'VideoDataPreprocessor', + 'CosineEMA', + 'ResLayerExtraNorm', + 'MultiModalDataPreprocessor', + 'QuickGELU', + 'SwiGLUFFN', + 'SwiGLUFFNFused', + 'RotaryEmbeddingFast', + 'SparseAvgPooling', + 'SparseConv2d', + 'SparseHelper', + 'SparseMaxPooling', + 'SparseBatchNorm2d', + 'SparseLayerNorm2D', + 'SparseSyncBatchNorm2d', +] + +if WITH_MULTIMODAL: + from .huggingface import (no_load_hf_pretrained_model, register_hf_model, + register_hf_tokenizer) + from .tokenizer import (Blip2Tokenizer, BlipTokenizer, FullTokenizer, + OFATokenizer) + + __all__.extend([ + 'BlipTokenizer', 'OFATokenizer', 'Blip2Tokenizer', 'register_hf_model', + 'register_hf_tokenizer', 'no_load_hf_pretrained_model', 'FullTokenizer' + ]) diff --git a/mmpretrain/models/utils/attention.py b/mmpretrain/models/utils/attention.py new file mode 100644 index 0000000..e92f605 --- /dev/null +++ b/mmpretrain/models/utils/attention.py @@ -0,0 +1,1129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +import warnings +from functools import partial +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import digit_version + +from mmpretrain.registry import MODELS +from .helpers import to_2tuple +from .layer_scale import LayerScale + +# After pytorch v1.10.0, use torch.meshgrid without indexing +# will raise extra warning. For more details, +# refers to https://github.com/pytorch/pytorch/issues/50276 +if digit_version(torch.__version__) >= digit_version('1.10.0'): + torch_meshgrid = partial(torch.meshgrid, indexing='ij') +else: + torch_meshgrid = torch.meshgrid + + +def scaled_dot_product_attention_pyimpl(query, + key, + value, + attn_mask=None, + dropout_p=0., + scale=None, + is_causal=False): + scale = scale or query.size(-1)**0.5 + if is_causal and attn_mask is not None: + attn_mask = torch.ones( + query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0) + if attn_mask is not None and attn_mask.dtype == torch.bool: + attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) + + attn_weight = query @ key.transpose(-2, -1) / scale + if attn_mask is not None: + attn_weight += attn_mask + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + return attn_weight @ value + + +if digit_version(torch.__version__) >= digit_version('2.0.0'): + scaled_dot_product_attention = F.scaled_dot_product_attention +else: + scaled_dot_product_attention = scaled_dot_product_attention_pyimpl + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. 
+ Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop (float, optional): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float, optional): Dropout ratio of output. Defaults to 0. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + super(WindowMSA, self).init_weights() + + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class WindowMSAV2(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Based on implementation on Swin Transformer V2 original repo. Refers to + https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer_v2.py + for more details. + + Args: + embed_dims (int): Number of input channels. 
+ window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + attn_drop (float): Dropout ratio of attention weight. + Defaults to 0. + proj_drop (float): Dropout ratio of output. Defaults to 0. + cpb_mlp_hidden_dims (int): The hidden dimensions of the continuous + relative position bias network. Defaults to 512. + pretrained_window_size (tuple(int)): The height and width of the window + in pre-training. Defaults to (0, 0), which means not load + pretrained model. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + cpb_mlp_hidden_dims=512, + pretrained_window_size=(0, 0), + init_cfg=None): + + super().__init__(init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + + # Use small network for continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear( + in_features=2, out_features=cpb_mlp_hidden_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear( + in_features=cpb_mlp_hidden_dims, + out_features=num_heads, + bias=False)) + + # Add learnable scalar for cosine attention + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # get relative_coords_table + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), + self.window_size[0], + dtype=torch.float32) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), + self.window_size[1], + dtype=torch.float32) + relative_coords_table = torch.stack( + torch_meshgrid([relative_coords_h, relative_coords_w])).permute( + 1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= ( + pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= ( + pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + self.register_buffer('relative_coords_table', relative_coords_table) + + # get pair-wise relative position index + # for each token inside the window + indexes_h = torch.arange(self.window_size[0]) + indexes_w = torch.arange(self.window_size[1]) + coordinates = torch.stack( + torch_meshgrid([indexes_h, indexes_w]), dim=0) # 2, Wh, Ww + coordinates = torch.flatten(coordinates, start_dim=1) # 2, Wh*Ww + # 2, Wh*Ww, Wh*Ww + relative_coordinates = coordinates[:, :, None] - coordinates[:, + None, :] + relative_coordinates = relative_coordinates.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + + relative_coordinates[:, :, 0] += self.window_size[ + 0] - 1 # shift to start from 0 + relative_coordinates[:, :, 1] += self.window_size[1] - 1 + relative_coordinates[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coordinates.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(embed_dims)) + else: + self.q_bias = None + 
self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww, + Wh*Ww), value should be between (-inf, 0]. + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + (self.q_bias, + torch.zeros_like(self.v_bias, + requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = ( + F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp( + self.logit_scale, max=np.log(1. / 0.01)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp( + self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +@MODELS.register_module() +class ShiftWindowMSA(BaseModule): + """Shift Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults to dict(type='DropPath', drop_prob=0.). + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + window_msa (Callable): To build a window multi-head attention module. + Defaults to :class:`WindowMSA`. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + **kwargs: Other keyword arguments to build the window multi-head + attention module. 
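+
+    Examples:
+        A usage sketch with illustrative values (not taken from any config in
+        this patch)::
+
+            >>> import torch
+            >>> attn = ShiftWindowMSA(
+            ...     embed_dims=96, num_heads=3, window_size=7, shift_size=3)
+            >>> x = torch.rand(2, 56 * 56, 96)    # (B, H*W, C)
+            >>> attn(x, hw_shape=(56, 56)).shape
+            torch.Size([2, 3136, 96])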
+ """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + pad_small_map=False, + window_msa=WindowMSA, + init_cfg=None, + **kwargs): + super().__init__(init_cfg) + + self.shift_size = shift_size + self.window_size = window_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = window_msa( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(self.window_size), + **kwargs, + ) + + self.drop = build_dropout(dropout_layer) + self.pad_small_map = pad_small_map + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, f"The query length {L} doesn't match the input "\ + f'shape ({H}, {W}).' + query = query.view(B, H, W, C) + + window_size = self.window_size + shift_size = self.shift_size + + if min(H, W) == window_size: + # If not pad small feature map, avoid shifting when the window size + # is equal to the size of feature map. It's to align with the + # behavior of the original implementation. + shift_size = shift_size if self.pad_small_map else 0 + elif min(H, W) < window_size: + # In the original implementation, the window size will be shrunk + # to the size of feature map. The behavior is different with + # swin-transformer for downstream tasks. To support dynamic input + # shape, we don't allow this feature. + assert self.pad_small_map, \ + f'The input shape ({H}, {W}) is smaller than the window ' \ + f'size ({window_size}). Please set `pad_small_map=True`, or ' \ + 'decrease the `window_size`.' + + pad_r = (window_size - W % window_size) % window_size + pad_b = (window_size - H % window_size) % window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if shift_size > 0: + query = torch.roll( + query, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + attn_mask = self.get_attn_mask((H_pad, W_pad), + window_size=window_size, + shift_size=shift_size, + device=query.device) + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(query, window_size) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, window_size, window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, + window_size) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) + else: + x = shifted_x + + if H != H_pad or W != W_pad: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + + return x + + @staticmethod + def window_reverse(windows, H, W, window_size): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + @staticmethod + def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + @staticmethod + def get_attn_mask(hw_shape, window_size, shift_size, device=None): + if shift_size > 0: + img_mask = torch.zeros(1, *hw_shape, 
1, device=device) + h_slices = (slice(0, -window_size), slice(-window_size, + -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -window_size), slice(-window_size, + -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = ShiftWindowMSA.window_partition( + img_mask, window_size) + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0) + attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0) + else: + attn_mask = None + return attn_mask + + +class MultiheadAttention(BaseModule): + """Multi-head Attention Module. + + This module implements multi-head attention that supports different input + dims and embed dims. And it also supports a shortcut from ``value``, which + is useful if input dims is not the same with embed dims. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + use_layer_scale (bool): Whether to use layer scale. Defaults to False. + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 0. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + input_dims=None, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + qkv_bias=True, + qk_scale=None, + proj_bias=True, + v_shortcut=False, + use_layer_scale=False, + layer_scale_init_value=0., + init_cfg=None): + super(MultiheadAttention, self).__init__(init_cfg=init_cfg) + + self.input_dims = input_dims or embed_dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.v_shortcut = v_shortcut + + self.head_dims = embed_dims // num_heads + if qk_scale is not None: + self.scaled_dot_product_attention = partial( + scaled_dot_product_attention_pyimpl, + scale=self.head_dims**-0.5) + else: + self.scaled_dot_product_attention = scaled_dot_product_attention + + self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.out_drop = build_dropout(dropout_layer) + + if use_layer_scale: + warnings.warn('The `use_layer_scale` in `MultiheadAttention` will ' + 'be deprecated. 
Please use `layer_scale_init_value` ' + 'to control whether using layer scale or not.') + + if use_layer_scale or (layer_scale_init_value > 0): + layer_scale_init_value = layer_scale_init_value or 1e-5 + self.gamma1 = LayerScale( + embed_dims, layer_scale_init_value=layer_scale_init_value) + else: + self.gamma1 = nn.Identity() + + def forward(self, x): + B, N, _ = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn_drop = self.attn_drop if self.training else 0. + x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + + x = self.proj(x) + x = self.out_drop(self.gamma1(self.proj_drop(x))) + + if self.v_shortcut: + x = v.squeeze(1) + x + return x + + +class BEiTAttention(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + The initial implementation is in MMSegmentation. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int, int]): The height and width of the window. + use_rel_pos_bias (bool): Whether to use unique relative position bias, + if False, use shared relative position bias defined in backbone. + bias (str): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. 
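+
+    Examples:
+        A small bookkeeping sketch (illustrative only): with
+        ``window_size=(14, 14)`` and ``use_rel_pos_bias=True``, the bias table
+        holds one entry per relative offset plus three extra entries for the
+        cls-token interactions::
+
+            >>> Wh = Ww = 14
+            >>> (2 * Wh - 1) * (2 * Ww - 1) + 3    # num_relative_distance
+            732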
+ """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + use_rel_pos_bias, + bias='qv_bias', + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.bias = bias + self.scale = qk_scale or head_embed_dims**-0.5 + + qkv_bias = bias + if bias == 'qv_bias': + self._init_qv_bias() + qkv_bias = False + + if window_size is None: + assert not use_rel_pos_bias + else: + assert isinstance(window_size, tuple) + self.window_size = window_size + self.use_rel_pos_bias = use_rel_pos_bias + self._init_rel_pos_embedding() + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + def _init_qv_bias(self): + self.q_bias = nn.Parameter(torch.zeros(self.embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(self.embed_dims)) + + def _init_rel_pos_embedding(self): + if self.use_rel_pos_bias: + Wh, Ww = self.window_size + # cls to token & token 2 cls & cls to cls + self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3 + # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH) + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, self.num_heads)) + + # get pair-wise relative position index for + # each token inside the window + coords_h = torch.arange(Wh) + coords_w = torch.arange(Ww) + # coords shape is (2, Wh, Ww) + coords = torch.stack(torch_meshgrid([coords_h, coords_w])) + # coords_flatten shape is (2, Wh*Ww) + coords_flatten = torch.flatten(coords, 1) + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :]) + # relative_coords shape is (Wh*Ww, Wh*Ww, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + # shift to start from 0 + relative_coords[:, :, 0] += Wh - 1 + relative_coords[:, :, 1] += Ww - 1 + relative_coords[:, :, 0] *= 2 * Ww - 1 + relative_position_index = torch.zeros( + size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype) + # relative_position_index shape is (Wh*Ww, Wh*Ww) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + def init_weights(self): + super().init_weights() + if self.use_rel_pos_bias: + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, rel_pos_bias=None): + """ + Args: + x (tensor): input features with shape of (num_windows*B, N, C). + rel_pos_bias (tensor): input relative position bias with shape of + (num_heads, N, N). 
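# Standalone sketch (not part of the patch) reproducing the relative-position
# index bookkeeping from _init_rel_pos_embedding above for a tiny 2x2 window,
# using torch.meshgrid directly instead of the torch_meshgrid wrapper.
import torch

Wh, Ww = 2, 2
coords = torch.stack(
    torch.meshgrid(torch.arange(Wh), torch.arange(Ww), indexing='ij'))
coords_flatten = torch.flatten(coords, 1)                      # (2, Wh*Ww)
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, Wh*Ww, Wh*Ww)
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += Wh - 1
rel[:, :, 1] += Ww - 1
rel[:, :, 0] *= 2 * Ww - 1
num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3        # 12 bias rows
index = torch.zeros((Wh * Ww + 1, ) * 2, dtype=rel.dtype)
index[1:, 1:] = rel.sum(-1)
index[0, 0:] = num_relative_distance - 3   # cls -> token
index[0:, 0] = num_relative_distance - 2   # token -> cls
index[0, 0] = num_relative_distance - 1    # cls -> cls
assert index.shape == (Wh * Ww + 1, Wh * Ww + 1)      # (5, 5)
assert int(index.max()) == num_relative_distance - 1  # 11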
+ """ + B, N, C = x.shape + + if self.bias == 'qv_bias': + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + Wh = self.window_size[0] + Ww = self.window_size[1] + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + Wh * Ww + 1, Wh * Ww + 1, -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + # use shared relative position bias + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class ChannelMultiheadAttention(BaseModule): + """Channel Multihead Self-attention Module. + + This module implements channel multi-head attention that supports different + input dims and embed dims. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to False. + proj_bias (bool): If True, add a learnable bias to output projection. + Defaults to True. + qk_scale_type (str): The scale type of qk scale. + Defaults to 'learnable'. It can be 'learnable', 'fixed' or 'none'. + qk_scale (float, optional): If ``qk_scale_type`` is set to 'none', this + should be specified as a valid float number. Defaults to None. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None.
+ """ + + def __init__(self, + embed_dims, + num_heads=8, + input_dims=None, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + qkv_bias=False, + proj_bias=True, + qk_scale_type='learnable', + qk_scale=None, + v_shortcut=False, + init_cfg=None): + super().__init__(init_cfg) + + self.input_dims = input_dims or embed_dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.v_shortcut = v_shortcut + + self.head_dims = embed_dims // num_heads + if qk_scale_type == 'learnable': + self.scale = nn.Parameter(torch.ones(num_heads, 1, 1)) + elif qk_scale_type == 'fixed': + self.scale = self.head_dims**-0.5 + elif qk_scale_type == 'none': + assert qk_scale is not None + self.scale = qk_scale + + self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.out_drop = build_dropout(dropout_layer) + + def forward(self, x): + B, N, _ = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + + q, k, v = [item.transpose(-2, -1) for item in [qkv[0], qkv[1], qkv[2]]] + + q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, self.embed_dims) + x = self.proj(x) + x = self.out_drop(self.proj_drop(x)) + + if self.v_shortcut: + x = qkv[2].squeeze(1) + x + return x + + +class LeAttention(BaseModule): + """LeViT Attention. Multi-head attention with attention bias, which is + proposed in `LeViT: a Vision Transformer in ConvNet’s Clothing for Faster + Inference`_ + + Args: + dim (int): Number of input channels. + key_dim (int): Dimension of the key for each attention head. + num_heads (int): Number of attention heads. Default: 8. + attn_ratio (int): Expansion ratio of the per-head value dimension + relative to ``key_dim``. Default: 4. + resolution (tuple[int]): Input resolution. Default: (14, 14). + init_cfg (dict, optional): The Config for initialization. + Defaults to None.
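# Standalone sketch (not part of the patch): the attention-bias table built in
# LeAttention.__init__ below stores one learnable bias per *distinct absolute
# offset*, not per token pair. For a toy 4x4 resolution that is 16 offsets
# shared by 256 query/key pairs.
import itertools

points = list(itertools.product(range(4), range(4)))
offsets = {(abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
           for p1 in points for p2 in points}
print(len(points) ** 2, len(offsets))   # 256 16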
+ """ + + def __init__(self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list( + itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer( + 'attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, + -1).split([self.key_dim, self.key_dim, self.d], + dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ((q @ k.transpose(-2, -1)) * self.scale + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab)) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class CrossMultiheadAttention(BaseModule): + """Cross attention between queries and the union of keys and values. + + This module is different from ``MultiheadAttention``, for the attention + is computed between queries and the union of keys and values. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + """ + + def __init__(self, + embed_dims: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0.) 
-> None: + super().__init__() + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(embed_dims, embed_dims, bias=False) + self.k = nn.Linear(embed_dims, embed_dims, bias=False) + self.v = nn.Linear(embed_dims, embed_dims, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(embed_dims)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, + x: torch.Tensor, + k: torch.Tensor = None, + v: torch.Tensor = None) -> None: + """Forward function.""" + B, N, _ = x.shape + + N_k = k.shape[1] + N_v = v.shape[1] + + q_bias, k_bias, v_bias = None, None, None + if self.q_bias is not None: + q_bias = self.q_bias + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + v_bias = self.v_bias + + q = F.linear( + input=x, weight=self.q.weight, bias=q_bias) # (B, N_q, dim) + k = F.linear( + input=k, weight=self.k.weight, bias=k_bias) # (B, N_k, dim) + v = F.linear(input=v, weight=self.v.weight, bias=v_bias) + + q = q.reshape(B, N, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_q, dim) + k = k.reshape(B, N_k, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_k, dim) + v = v.reshape(B, N_v, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, num_heads, N_v, dim) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class PromptMultiheadAttention(MultiheadAttention): + """Prompt Multihead Attention for MILAN. + + This module is specific for the prompt encoder in MILAN. It will not update + the visible tokens from the encoder. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + input_dims (int, optional): The input dimension, and if None, + use ``embed_dims``. Defaults to None. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + dropout_layer (dict): The dropout config before adding the shortcut. + Defaults to ``dict(type='Dropout', drop_prob=0.)``. + qkv_bias (bool): If True, add a learnable bias to q, k, v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + v_shortcut (bool): Add a shortcut from value to output. It's usually + used if ``input_dims`` is different from ``embed_dims``. + Defaults to False. + return_attention (bool): If True, return the attention map, computed by + the cross attention between the class token and all other tokens. + Defaults to False. + init_cfg (Union[List[dict], dict], optional): The Config for + initialization. Defaults to None. 
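# Illustrative, standalone shape sketch (not part of the patch) for the
# CrossMultiheadAttention module above: queries come from one sequence while
# keys/values come from another, e.g. decoder tokens attending to encoder
# tokens. The import path is assumed from this file's location.
import torch
from mmpretrain.models.utils import CrossMultiheadAttention  # assumed export

cross_attn = CrossMultiheadAttention(embed_dims=768, num_heads=12)
queries = torch.randn(2, 50, 768)     # e.g. mask/decoder tokens
context = torch.randn(2, 197, 768)    # e.g. encoder tokens
out = cross_attn(queries, k=context, v=context)
assert out.shape == queries.shape     # (2, 50, 768)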
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int, + input_dims: Optional[int] = None, + attn_drop: float = 0, + proj_drop: float = 0, + dropout_layer: dict = dict(type='Dropout', drop_prob=0.), + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + proj_bias: bool = True, + v_shortcut: bool = False, + use_layer_scale: bool = False, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + input_dims=input_dims, + attn_drop=attn_drop, + proj_drop=proj_drop, + dropout_layer=dropout_layer, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + proj_bias=proj_bias, + v_shortcut=v_shortcut, + use_layer_scale=use_layer_scale, + init_cfg=init_cfg) + # no longer need qkv + del self.qkv + + # to project the mask tokens + self.q = nn.Linear(embed_dims, embed_dims, bias=qkv_bias) + # to project al the tokens + self.kv = nn.Linear(embed_dims, embed_dims * 2, bias=qkv_bias) + + def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: + """Forward function for `PromptMultiheadAttention`. + + Args: + x (torch.Tensor): Mask token features with shape N x L_m x C. + visible_tokens (torch.Tensor): The visible tokens features from + encoder with shape N x L_v x C. + ids_restore (torch.Tensor): The ids of all tokens in the original + image with shape N x L. + + Returns: + torch Tensor: Output features with shape N x L x C. + """ + x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1) + assert x_.shape[1] == ids_restore.shape[1] + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1])) + x_ = torch.cat([visible_tokens[:, :1, :], x_], dim=1) + + # full sequence shape + B, _, _ = x_.shape + q = self.q(x).reshape(B, x.shape[1], self.num_heads, + self.head_dims).permute(0, 2, 1, 3) + kv = self.kv(x_).reshape(B, x_.shape[1], 2, self.num_heads, + self.head_dims).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn_drop = self.attn_drop if self.training else 0. + attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims) + + x = self.proj(x) + x = self.out_drop(self.gamma1(self.proj_drop(x))) + return x diff --git a/mmpretrain/models/utils/batch_augments/__init__.py b/mmpretrain/models/utils/batch_augments/__init__.py new file mode 100644 index 0000000..2fbc4e1 --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .cutmix import CutMix +from .mixup import Mixup +from .resizemix import ResizeMix +from .wrapper import RandomBatchAugment + +__all__ = ('RandomBatchAugment', 'CutMix', 'Mixup', 'ResizeMix') diff --git a/mmpretrain/models/utils/batch_augments/cutmix.py b/mmpretrain/models/utils/batch_augments/cutmix.py new file mode 100644 index 0000000..665427b --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/cutmix.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import numpy as np +import torch + +from mmpretrain.registry import BATCH_AUGMENTS +from .mixup import Mixup + + +@BATCH_AUGMENTS.register_module() +class CutMix(Mixup): + r"""CutMix batch agumentation. + + CutMix is a method to improve the network's generalization capability. 
It's + proposed in `CutMix: Regularization Strategy to Train Strong Classifiers + with Localizable Features ` + + With this method, patches are cut and pasted among training images where + the ground truth labels are also mixed proportionally to the area of the + patches. + + Args: + alpha (float): Parameters for Beta distribution to generate the + mixing ratio. It should be a positive number. More details + can be found in :class:`Mixup`. + cutmix_minmax (List[float], optional): The min/max area ratio of the + patches. If not None, the bounding-box of patches is uniform + sampled within this ratio range, and the ``alpha`` will be ignored. + Otherwise, the bounding-box is generated according to the + ``alpha``. Defaults to None. + correct_lam (bool): Whether to apply lambda correction when cutmix bbox + clipped by image borders. Defaults to True. + + .. note :: + If the ``cutmix_minmax`` is None, how to generate the bounding-box of + patches according to the ``alpha``? + + First, generate a :math:`\lambda`, details can be found in + :class:`Mixup`. And then, the area ratio of the bounding-box + is calculated by: + + .. math:: + \text{ratio} = \sqrt{1-\lambda} + """ + + def __init__(self, + alpha: float, + cutmix_minmax: Optional[List[float]] = None, + correct_lam: bool = True): + super().__init__(alpha=alpha) + + self.cutmix_minmax = cutmix_minmax + self.correct_lam = correct_lam + + def rand_bbox_minmax( + self, + img_shape: Tuple[int, int], + count: Optional[int] = None) -> Tuple[int, int, int, int]: + """Min-Max CutMix bounding-box Inspired by Darknet cutmix + implementation. It generates a random rectangular bbox based on min/max + percent values applied to each dimension of the input image. + + Typical defaults for minmax are usually in the .2-.3 for min and + .8-.9 range for max. + + Args: + img_shape (tuple): Image shape as tuple + count (int, optional): Number of bbox to generate. Defaults to None + """ + assert len(self.cutmix_minmax) == 2 + img_h, img_w = img_shape + cut_h = np.random.randint( + int(img_h * self.cutmix_minmax[0]), + int(img_h * self.cutmix_minmax[1]), + size=count) + cut_w = np.random.randint( + int(img_w * self.cutmix_minmax[0]), + int(img_w * self.cutmix_minmax[1]), + size=count) + yl = np.random.randint(0, img_h - cut_h, size=count) + xl = np.random.randint(0, img_w - cut_w, size=count) + yu = yl + cut_h + xu = xl + cut_w + return yl, yu, xl, xu + + def rand_bbox(self, + img_shape: Tuple[int, int], + lam: float, + margin: float = 0., + count: Optional[int] = None) -> Tuple[int, int, int, int]: + """Standard CutMix bounding-box that generates a random square bbox + based on lambda value. This implementation includes support for + enforcing a border margin as percent of bbox dimensions. + + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + margin (float): Percentage of bbox dimension to enforce as margin + (reduce amount of box outside image). Defaults to 0. + count (int, optional): Number of bbox to generate. 
Defaults to None + """ + ratio = np.sqrt(1 - lam) + img_h, img_w = img_shape + cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) + margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) + cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) + cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) + yl = np.clip(cy - cut_h // 2, 0, img_h) + yh = np.clip(cy + cut_h // 2, 0, img_h) + xl = np.clip(cx - cut_w // 2, 0, img_w) + xh = np.clip(cx + cut_w // 2, 0, img_w) + return yl, yh, xl, xh + + def cutmix_bbox_and_lam(self, + img_shape: Tuple[int, int], + lam: float, + count: Optional[int] = None) -> tuple: + """Generate bbox and apply lambda correction. + + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + count (int, optional): Number of bbox to generate. Defaults to None + """ + if self.cutmix_minmax is not None: + yl, yu, xl, xu = self.rand_bbox_minmax(img_shape, count=count) + else: + yl, yu, xl, xu = self.rand_bbox(img_shape, lam, count=count) + if self.correct_lam or self.cutmix_minmax is not None: + bbox_area = (yu - yl) * (xu - xl) + lam = 1. - bbox_area / float(img_shape[0] * img_shape[1]) + return (yl, yu, xl, xu), lam + + def mix(self, batch_inputs: torch.Tensor, + batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. + batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor): The mixed inputs and labels. + """ + lam = np.random.beta(self.alpha, self.alpha) + batch_size = batch_inputs.size(0) + img_shape = batch_inputs.shape[-2:] + index = torch.randperm(batch_size) + + (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) + batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2] + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] + + return batch_inputs, mixed_scores diff --git a/mmpretrain/models/utils/batch_augments/mixup.py b/mmpretrain/models/utils/batch_augments/mixup.py new file mode 100644 index 0000000..bedb2c3 --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/mixup.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import numpy as np +import torch + +from mmpretrain.registry import BATCH_AUGMENTS + + +@BATCH_AUGMENTS.register_module() +class Mixup: + r"""Mixup batch augmentation. + + Mixup is a method to reduces the memorization of corrupt labels and + increases the robustness to adversarial examples. It's proposed in + `mixup: Beyond Empirical Risk Minimization + `_ + + Args: + alpha (float): Parameters for Beta distribution to generate the + mixing ratio. It should be a positive number. More details + are in the note. + + Note: + The :math:`\alpha` (``alpha``) determines a random distribution + :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample + a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random + distribution. + """ + + def __init__(self, alpha: float): + assert isinstance(alpha, float) and alpha > 0 + + self.alpha = alpha + + def mix(self, batch_inputs: torch.Tensor, + batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. 
+ batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor): The mixed inputs and labels. + """ + lam = np.random.beta(self.alpha, self.alpha) + batch_size = batch_inputs.size(0) + index = torch.randperm(batch_size) + + mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :] + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] + + return mixed_inputs, mixed_scores + + def __call__(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): + """Mix the batch inputs and batch data samples.""" + assert batch_score.ndim == 2, \ + 'The input `batch_score` should be a one-hot format tensor, '\ + 'which shape should be ``(N, num_classes)``.' + + mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score.float()) + return mixed_inputs, mixed_score diff --git a/mmpretrain/models/utils/batch_augments/resizemix.py b/mmpretrain/models/utils/batch_augments/resizemix.py new file mode 100644 index 0000000..c70f81b --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/resizemix.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from mmpretrain.registry import BATCH_AUGMENTS +from .cutmix import CutMix + + +@BATCH_AUGMENTS.register_module() +class ResizeMix(CutMix): + r"""ResizeMix Random Paste layer for a batch of data. + + The ResizeMix will resize an image to a small patch and paste it on another + image. It's proposed in `ResizeMix: Mixing Data with Preserved Object + Information and True Labels `_ + + Args: + alpha (float): Parameters for Beta distribution to generate the + mixing ratio. It should be a positive number. More details + can be found in :class:`Mixup`. + lam_min(float): The minimum value of lam. Defaults to 0.1. + lam_max(float): The maximum value of lam. Defaults to 0.8. + interpolation (str): algorithm used for upsampling: + 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | + 'area'. Defaults to 'bilinear'. + prob (float): The probability to execute resizemix. It should be in + range [0, 1]. Defaults to 1.0. + cutmix_minmax (List[float], optional): The min/max area ratio of the + patches. If not None, the bounding-box of patches is uniform + sampled within this ratio range, and the ``alpha`` will be ignored. + Otherwise, the bounding-box is generated according to the + ``alpha``. Defaults to None. + correct_lam (bool): Whether to apply lambda correction when cutmix bbox + clipped by image borders. Defaults to True + **kwargs: Any other parameters accpeted by :class:`CutMix`. + + Note: + The :math:`\lambda` (``lam``) is the mixing ratio. It's a random + variable which follows :math:`Beta(\alpha, \alpha)` and is mapped + to the range [``lam_min``, ``lam_max``]. + + .. math:: + \lambda = \frac{Beta(\alpha, \alpha)} + {\lambda_{max} - \lambda_{min}} + \lambda_{min} + + And the resize ratio of source images is calculated by :math:`\lambda`: + + .. 
math:: + \text{ratio} = \sqrt{1-\lambda} + """ + + def __init__(self, + alpha: float, + lam_min: float = 0.1, + lam_max: float = 0.8, + interpolation: str = 'bilinear', + cutmix_minmax: Optional[List[float]] = None, + correct_lam: bool = True): + super().__init__( + alpha=alpha, cutmix_minmax=cutmix_minmax, correct_lam=correct_lam) + self.lam_min = lam_min + self.lam_max = lam_max + self.interpolation = interpolation + + def mix(self, batch_inputs: torch.Tensor, + batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. + batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor): The mixed inputs and labels. + """ + lam = np.random.beta(self.alpha, self.alpha) + lam = lam * (self.lam_max - self.lam_min) + self.lam_min + img_shape = batch_inputs.shape[-2:] + batch_size = batch_inputs.size(0) + index = torch.randperm(batch_size) + + (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) + batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate( + batch_inputs[index], + size=(int(y2 - y1), int(x2 - x1)), + mode=self.interpolation, + align_corners=False) + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] + + return batch_inputs, mixed_scores diff --git a/mmpretrain/models/utils/batch_augments/wrapper.py b/mmpretrain/models/utils/batch_augments/wrapper.py new file mode 100644 index 0000000..10e5304 --- /dev/null +++ b/mmpretrain/models/utils/batch_augments/wrapper.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Union + +import numpy as np +import torch + +from mmpretrain.registry import BATCH_AUGMENTS + + +class RandomBatchAugment: + """Randomly choose one batch augmentation to apply. + + Args: + augments (Callable | dict | list): configs of batch + augmentations. + probs (float | List[float] | None): The probabilities of each batch + augmentations. If None, choose evenly. Defaults to None. + + Example: + >>> import torch + >>> import torch.nn.functional as F + >>> from mmpretrain.models import RandomBatchAugment + >>> augments_cfg = [ + ... dict(type='CutMix', alpha=1.), + ... dict(type='Mixup', alpha=1.) + ... ] + >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3]) + >>> imgs = torch.rand(16, 3, 32, 32) + >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10) + >>> imgs, label = batch_augment(imgs, label) + + .. note :: + + To decide which batch augmentation will be used, it picks one of + ``augments`` based on the probabilities. In the example above, the + probability to use CutMix is 0.5, to use Mixup is 0.3, and to do + nothing is 0.2. + """ + + def __init__(self, augments: Union[Callable, dict, list], probs=None): + if not isinstance(augments, (tuple, list)): + augments = [augments] + + self.augments = [] + for aug in augments: + if isinstance(aug, dict): + self.augments.append(BATCH_AUGMENTS.build(aug)) + else: + self.augments.append(aug) + + if isinstance(probs, float): + probs = [probs] + + if probs is not None: + assert len(augments) == len(probs), \ + '``augments`` and ``probs`` must have same lengths. ' \ + f'Got {len(augments)} vs {len(probs)}.' + assert sum(probs) <= 1, \ + 'The total probability of batch augments exceeds 1.' 
+ self.augments.append(None) + probs.append(1 - sum(probs)) + + self.probs = probs + + def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor): + """Randomly apply batch augmentations to the batch inputs and batch + data samples.""" + aug_index = np.random.choice(len(self.augments), p=self.probs) + aug = self.augments[aug_index] + + if aug is not None: + return aug(batch_input, batch_score) + else: + return batch_input, batch_score.float() diff --git a/mmpretrain/models/utils/batch_shuffle.py b/mmpretrain/models/utils/batch_shuffle.py new file mode 100644 index 0000000..a0b03c5 --- /dev/null +++ b/mmpretrain/models/utils/batch_shuffle.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from mmengine.dist import all_gather, broadcast, get_rank + + +@torch.no_grad() +def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Batch shuffle, for making use of BatchNorm. + + Args: + x (torch.Tensor): Data in each GPU. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Output of shuffle operation. + - x_gather[idx_this]: Shuffled data. + - idx_unshuffle: Index for restoring. + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = torch.cat(all_gather(x), dim=0) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all) + + # broadcast to all gpus + broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + +@torch.no_grad() +def batch_unshuffle_ddp(x: torch.Tensor, + idx_unshuffle: torch.Tensor) -> torch.Tensor: + """Undo batch shuffle. + + Args: + x (torch.Tensor): Data in each GPU. + idx_unshuffle (torch.Tensor): Index for restoring. + + Returns: + torch.Tensor: Output of unshuffle operation. + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = torch.cat(all_gather(x), dim=0) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] diff --git a/mmpretrain/models/utils/box_utils.py b/mmpretrain/models/utils/box_utils.py new file mode 100644 index 0000000..79db516 --- /dev/null +++ b/mmpretrain/models/utils/box_utils.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torchvision.ops.boxes as boxes + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2.0, (y0 + y1) / 2.0, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +def box_iou(boxes1, boxes2): + """Return intersection-over-union (Jaccard index) between two sets of + boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
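# Standalone round-trip check (not part of the patch) for the two box-format
# conversion helpers above, assuming this module is importable as
# mmpretrain.models.utils.box_utils.
import torch
from mmpretrain.models.utils.box_utils import (box_cxcywh_to_xyxy,
                                                box_xyxy_to_cxcywh)

b = torch.tensor([[10., 20., 50., 80.]])            # (x1, y1, x2, y2)
assert torch.allclose(box_cxcywh_to_xyxy(box_xyxy_to_cxcywh(b)), b)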
+ + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise IoU values for + every element in boxes1 and boxes2 + """ + return boxes.box_iou(boxes1, boxes2) + + +def generalized_box_iou(boxes1, boxes2): + """Return generalized intersection-over-union (Jaccard index) between two + sets of boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU + values for every element in boxes1 and boxes2 + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + + return boxes.generalized_box_iou(boxes1, boxes2) diff --git a/mmpretrain/models/utils/channel_shuffle.py b/mmpretrain/models/utils/channel_shuffle.py new file mode 100644 index 0000000..27006a8 --- /dev/null +++ b/mmpretrain/models/utils/channel_shuffle.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def channel_shuffle(x, groups): + """Channel Shuffle operation. + + This function enables cross-group information flow for multiple groups + convolution layers. + + Args: + x (Tensor): The input tensor. + groups (int): The number of groups to divide the input tensor + in the channel dimension. + + Returns: + Tensor: The output tensor after channel shuffle operation. + """ + + batch_size, num_channels, height, width = x.size() + assert (num_channels % groups == 0), ('num_channels should be ' + 'divisible by groups') + channels_per_group = num_channels // groups + + x = x.view(batch_size, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + x = x.view(batch_size, -1, height, width) + + return x diff --git a/mmpretrain/models/utils/clip_generator_helper.py b/mmpretrain/models/utils/clip_generator_helper.py new file mode 100644 index 0000000..4f67f0e --- /dev/null +++ b/mmpretrain/models/utils/clip_generator_helper.py @@ -0,0 +1,394 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/zejiangh/MILAN +from collections import OrderedDict +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.logging import MMLogger +from torch import nn + +from mmpretrain.registry import MODELS + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +@MODELS.register_module() +class QuickGELU(nn.Module): + """A faster version of GELU.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + """Residual Attention Block (RAB). + + This module implements the same function as the MultiheadAttention, + but with a different interface, which is mainly used + in CLIP. + + Args: + d_model (int): The feature dimension. + n_head (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + Defaults to None. 
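# Minimal standalone illustration (not part of the patch) of channel_shuffle
# above: with groups=2, the 4 channels [0, 1, 2, 3] are interleaved across the
# two groups and come out as [0, 2, 1, 3].
import torch

x = torch.arange(4).view(1, 4, 1, 1).float()
groups, channels_per_group = 2, 2
out = (x.view(1, groups, channels_per_group, 1, 1)
        .transpose(1, 2).contiguous().view(1, -1, 1, 1))
print(out.flatten().tolist())   # [0.0, 2.0, 1.0, 3.0]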
+ """ + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: Optional[torch.Tensor] = None, + return_attention: bool = False) -> None: + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + self.return_attention = return_attention + + def attention(self, x: torch.Tensor) -> torch.Tensor: + """Attention function.""" + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + if self.return_attention: + return self.attn( + x, + x, + x, + need_weights=self.return_attention, + attn_mask=self.attn_mask) + else: + return self.attn( + x, + x, + x, + need_weights=self.return_attention, + attn_mask=self.attn_mask)[0] + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Forward function.""" + if self.return_attention: + x_, attention = self.attention(self.ln_1(x)) + x = x + x_ + x = x + self.mlp(self.ln_2(x)) + return x, attention + else: + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +class VisionTransformer(nn.Module): + """Vision Transformer for CLIP. + + Args: + input_resolution (int): The image size. + patch_size (int): The patch size. + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + out_dim (int): The output dimension. + fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. 
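# Quick arithmetic check (not part of the patch) for the positional-embedding
# length used by the CLIP VisionTransformer above: one token per patch plus a
# class token.
input_resolution, patch_size = 224, 16
num_tokens = (input_resolution // patch_size) ** 2 + 1
print(num_tokens)   # 197 = 14 * 14 patches + 1 cls token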
+ """ + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + finetune=False, + average_targets: int = 1) -> None: + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.finetune = finetune + if finetune is False: + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.average_targets = average_targets + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function.""" + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([ + self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x + ], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, attention, z = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + + return x, attention + + +class CLIP(nn.Module): + """CLIP. + + Args: + embed_dim (int): The embedding dimension. + image_resolution (int): The image size. + vision_layers (int): The number of layers in the vision transformer. + vision_width (int): The feature dimension in the vision transformer. + vision_patch_size (int): The patch size in the vision transformer. + context_length (int): The context length. + vocab_size (int): The vocabulary size. + transformer_width (int): The feature dimension in the text transformer. + transformer_heads (int): The number of attention heads in the + text transformer. + transformer_layers (int): The number of layers in the text transformer. + fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. 
+ """ + + def __init__( + self, + embed_dim: int, + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + finetune: bool = False, + average_targets: int = 1, + ) -> None: + super().__init__() + + self.context_length = context_length + + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + finetune=finetune, + average_targets=average_targets, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self) -> None: + """Initialize the parameters. + + The pretrained weight will override the initialized parameters by this + function. + """ + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self) -> torch.Tensor: + """Build the attention mask.""" + # lazily create causal attention mask, with full attention between the + # vision tokens pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self) -> torch.dtype: + """Get the dtype.""" + return self.visual.conv1.weight.dtype + + def encode_image(self, + image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode the image. + + Get the feature and attention mask from the last layer of the visual + branch of CLIP. + + Args: + image (torch.Tensor): The image tensor with shape NCHW. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The feature and attention mask. + """ + return self.visual(image.type(self.dtype)) + + +def build_clip_model(state_dict: dict, + finetune: bool = False, + average_targets: int = 1) -> nn.Module: + """Build the CLIP model. + + Args: + state_dict (dict): The pretrained state dict. + finetune (bool): Whether to fineturn the model. + average_targets (bool): Whether to average the target. + + Returns: + nn.Module: The CLIP model. 
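# Standalone sketch (not part of the patch) of what CLIP.build_attention_mask
# above produces, shown here for a toy context length of 4: an additive causal
# mask with -inf above the diagonal and 0 elsewhere.
import torch

mask = torch.empty(4, 4).fill_(float('-inf')).triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])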
+ """ + vit = 'visual.proj' in state_dict + + if vit: + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) + image_resolution = vision_patch_size * grid_size + + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) + + model = CLIP( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + finetune, + average_targets, + ) + + for key in ['input_resolution', 'context_length', 'vocab_size']: + if key in state_dict: + del state_dict[key] + + msg = model.load_state_dict(state_dict, strict=False) + MMLogger.get_current_instance().info(f'Load CLIP model: {msg}') + return model.eval() diff --git a/mmpretrain/models/utils/data_preprocessor.py b/mmpretrain/models/utils/data_preprocessor.py new file mode 100644 index 0000000..c407bd4 --- /dev/null +++ b/mmpretrain/models/utils/data_preprocessor.py @@ -0,0 +1,620 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from numbers import Number +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from mmengine.model import (BaseDataPreprocessor, ImgDataPreprocessor, + stack_batch) + +from mmpretrain.registry import MODELS +from mmpretrain.structures import (DataSample, MultiTaskDataSample, + batch_label_to_onehot, cat_batch_labels, + tensor_split) +from .batch_augments import RandomBatchAugment + + +@MODELS.register_module() +class ClsDataPreprocessor(BaseDataPreprocessor): + """Image pre-processor for classification tasks. + + Comparing with the :class:`mmengine.model.ImgDataPreprocessor`, + + 1. It won't do normalization if ``mean`` is not specified. + 2. It does normalization and color space conversion after stacking batch. + 3. It supports batch augmentations like mixup and cutmix. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + - Do batch augmentations like Mixup and Cutmix during training. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (Number): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + to_onehot (bool): Whether to generate one-hot format gt-labels and set + to data samples. Defaults to False. 
+ num_classes (int, optional): The number of classes. Defaults to None. + batch_augments (dict, optional): The batch augmentations settings, + including "augments" and "probs". For more details, see + :class:`mmpretrain.models.RandomBatchAugment`. + """ + + def __init__(self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Number = 0, + to_rgb: bool = False, + to_onehot: bool = False, + num_classes: Optional[int] = None, + batch_augments: Optional[dict] = None): + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.to_onehot = to_onehot + self.num_classes = num_classes + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both `mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + if batch_augments: + self.batch_augments = RandomBatchAugment(**batch_augments) + if not self.to_onehot: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().info( + 'Because batch augmentations are enabled, the data ' + 'preprocessor automatically enables the `to_onehot` ' + 'option to generate one-hot format labels.') + self.to_onehot = True + else: + self.batch_augments = None + + def forward(self, data: dict, training: bool = False) -> dict: + """Perform normalization, padding, bgr2rgb conversion and batch + augmentation based on ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + dict: Data in the same format as the model input. + """ + inputs = self.cast_data(data['inputs']) + + if isinstance(inputs, torch.Tensor): + # The branch if use `default_collate` as the collate_fn in the + # dataloader. + + # ------ To RGB ------ + if self.to_rgb and inputs.size(1) == 3: + inputs = inputs.flip(1) + + # -- Normalization --- + inputs = inputs.float() + if self._enable_normalize: + inputs = (inputs - self.mean) / self.std + + # ------ Padding ----- + if self.pad_size_divisor > 1: + h, w = inputs.shape[-2:] + + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant', + self.pad_value) + else: + # The branch if use `pseudo_collate` as the collate_fn in the + # dataloader. 
+ + processed_inputs = [] + for input_ in inputs: + # ------ To RGB ------ + if self.to_rgb and input_.size(0) == 3: + input_ = input_.flip(0) + + # -- Normalization --- + input_ = input_.float() + if self._enable_normalize: + input_ = (input_ - self.mean) / self.std + + processed_inputs.append(input_) + # Combine padding and stack + inputs = stack_batch(processed_inputs, self.pad_size_divisor, + self.pad_value) + + data_samples = data.get('data_samples', None) + sample_item = data_samples[0] if data_samples is not None else None + + if isinstance(sample_item, DataSample): + batch_label = None + batch_score = None + + if 'gt_label' in sample_item: + gt_labels = [sample.gt_label for sample in data_samples] + batch_label, label_indices = cat_batch_labels(gt_labels) + batch_label = batch_label.to(self.device) + if 'gt_score' in sample_item: + gt_scores = [sample.gt_score for sample in data_samples] + batch_score = torch.stack(gt_scores).to(self.device) + elif self.to_onehot and 'gt_label' in sample_item: + assert batch_label is not None, \ + 'Cannot generate onehot format labels because no labels.' + num_classes = self.num_classes or sample_item.get( + 'num_classes') + assert num_classes is not None, \ + 'Cannot generate one-hot format labels because not set ' \ + '`num_classes` in `data_preprocessor`.' + batch_score = batch_label_to_onehot( + batch_label, label_indices, num_classes).to(self.device) + + # ----- Batch Augmentations ---- + if (training and self.batch_augments is not None + and batch_score is not None): + inputs, batch_score = self.batch_augments(inputs, batch_score) + + # ----- scatter labels and scores to data samples --- + if batch_label is not None: + for sample, label in zip( + data_samples, tensor_split(batch_label, + label_indices)): + sample.set_gt_label(label) + if batch_score is not None: + for sample, score in zip(data_samples, batch_score): + sample.set_gt_score(score) + elif isinstance(sample_item, MultiTaskDataSample): + data_samples = self.cast_data(data_samples) + + return {'inputs': inputs, 'data_samples': data_samples} + + +@MODELS.register_module() +class SelfSupDataPreprocessor(ImgDataPreprocessor): + """Image pre-processor for operations, like normalization and bgr to rgb. + + Compared with the :class:`mmengine.ImgDataPreprocessor`, this module + supports ``inputs`` as torch.Tensor or a list of torch.Tensor. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + bgr_to_rgb=bgr_to_rgb, + rgb_to_bgr=rgb_to_bgr, + non_blocking=non_blocking) + + self._channel_conversion = to_rgb or bgr_to_rgb or rgb_to_bgr + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. 
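# Hypothetical config sketch (not part of the patch) wiring up the
# ClsDataPreprocessor defined above with random Mixup/CutMix batch
# augmentation; the normalization statistics are the usual ImageNet values and
# every number here is illustrative only.
data_preprocessor = dict(
    type='ClsDataPreprocessor',
    num_classes=1000,
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
    batch_augments=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0),
    ]))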
+ """ + assert isinstance(data, + dict), 'Please use default_collate in dataloader, \ + instead of pseudo_collate.' + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + # Here is what is different from :class:`mmengine.ImgDataPreprocessor` + # Since there are multiple views for an image for some algorithms, + # e.g. SimCLR, each item in inputs is a list, containing multi-views + # for an image. + if isinstance(batch_inputs, list): + # channel transform + if self._channel_conversion: + batch_inputs = [ + _input[:, [2, 1, 0], ...] for _input in batch_inputs + ] + + # convert to float after channel conversion to ensure efficiency + batch_inputs = [_input.float() for _input in batch_inputs] + + # normalization. + if self._enable_normalize: + batch_inputs = [(_input - self.mean) / self.std + for _input in batch_inputs] + else: + # channel transform + if self._channel_conversion: + batch_inputs = batch_inputs[:, [2, 1, 0], ...] + + # convert to float after channel conversion to ensure efficiency + batch_inputs = batch_inputs.float() + + # normalization. + if self._enable_normalize: + batch_inputs = (batch_inputs - self.mean) / self.std + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class TwoNormDataPreprocessor(SelfSupDataPreprocessor): + """Image pre-processor for CAE, BEiT v1/v2, etc. + + Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module + will normalize the prediction image and target image with different + normalization parameters. + + Args: + mean (Sequence[float or int], optional): The pixel mean of image + channels. If ``to_rgb=True`` it means the mean value of R, G, B + channels. If the length of `mean` is 1, it means all channels have + the same mean value, or the input is a gray image. If it is not + specified, images will not be normalized. Defaults to None. + std (Sequence[float or int], optional): The pixel standard deviation of + image channels. If ``to_rgb=True`` it means the standard deviation + of R, G, B channels. If the length of `std` is 1, it means all + channels have the same standard deviation, or the input is a gray + image. If it is not specified, images will not be normalized. + Defaults to None. + second_mean (Sequence[float or int], optional): The description is + like ``mean``, it can be customized for targe image. Defaults to + None. + second_std (Sequence[float or int], optional): The description is + like ``std``, it can be customized for targe image. Defaults to + None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + non_blocking (bool): Whether block current process when transferring + data to device. Defaults to False. 
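# Hypothetical config sketch (not part of the patch) for the
# TwoNormDataPreprocessor above: `mean`/`std` normalize the view fed to the
# student branch, while `second_mean`/`second_std` normalize the view fed to
# the target/tokenizer branch. All statistics below are illustrative
# placeholders, not recommended values.
data_preprocessor = dict(
    type='TwoNormDataPreprocessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    second_mean=[127.5, 127.5, 127.5],
    second_std=[127.5, 127.5, 127.5],
    to_rgb=True)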
+    """
+
+    def __init__(self,
+                 mean: Optional[Sequence[Union[float, int]]] = None,
+                 std: Optional[Sequence[Union[float, int]]] = None,
+                 second_mean: Sequence[Union[float, int]] = None,
+                 second_std: Sequence[Union[float, int]] = None,
+                 pad_size_divisor: int = 1,
+                 pad_value: Union[float, int] = 0,
+                 to_rgb: bool = False,
+                 non_blocking: Optional[bool] = False):
+        super().__init__(
+            mean=mean,
+            std=std,
+            pad_size_divisor=pad_size_divisor,
+            pad_value=pad_value,
+            to_rgb=to_rgb,
+            non_blocking=non_blocking)
+        assert (second_mean is not None) and (second_std is not None), (
+            '`second_mean` and `second_std` should not be None while using '
+            '`TwoNormDataPreprocessor`')
+        assert len(second_mean) == 3 or len(second_mean) == 1, (
+            '`second_mean` should have 1 or 3 values, to be compatible with '
+            f'RGB or gray image, but got {len(second_mean)} values')
+        assert len(second_std) == 3 or len(second_std) == 1, (
+            '`second_std` should have 1 or 3 values, to be compatible with '
+            f'RGB or gray image, but got {len(second_std)} values')
+
+        self.register_buffer('second_mean',
+                             torch.tensor(second_mean).view(-1, 1, 1), False)
+        self.register_buffer('second_std',
+                             torch.tensor(second_std).view(-1, 1, 1), False)
+
+    def forward(
+        self,
+        data: dict,
+        training: bool = False
+    ) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Performs normalization and bgr2rgb conversion based on
+        ``BaseDataPreprocessor``. The ``batch_inputs`` in the forward
+        function is a list.
+
+        Args:
+            data (dict): data sampled from dataloader.
+            training (bool): Whether to enable training time augmentation. If
+                subclasses override this method, they can perform different
+                preprocessing strategies for training and testing based on the
+                value of ``training``.
+        Returns:
+            Tuple[torch.Tensor, Optional[list]]: Data in the same format as
+                the model input.
+        """
+        data = [val for _, val in data.items()]
+        batch_inputs, batch_data_samples = self.cast_data(data)
+        # channel transform
+        if self._channel_conversion:
+            batch_inputs = [
+                _input[:, [2, 1, 0], ...] for _input in batch_inputs
+            ]
+
+        # convert to float after channel conversion to ensure efficiency
+        batch_inputs = [_input.float() for _input in batch_inputs]
+
+        # Normalization. Here is what is different from
+        # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target
+        # image and prediction image with different normalization params
+        if self._enable_normalize:
+            batch_inputs = [
+                (batch_inputs[0] - self.mean) / self.std,
+                (batch_inputs[1] - self.second_mean) / self.second_std
+            ]
+
+        return {'inputs': batch_inputs, 'data_samples': batch_data_samples}
+
+
+@MODELS.register_module()
+class VideoDataPreprocessor(BaseDataPreprocessor):
+    """Video pre-processor for operations, like normalization and bgr to rgb
+    conversion.
+
+    Compared with the :class:`mmaction.ActionDataPreprocessor`, this module
+    supports ``inputs`` as torch.Tensor or a list of torch.Tensor.
+
+    Args:
+        mean (Sequence[float or int], optional): The pixel mean of channels
+            of images or stacked optical flow. Defaults to None.
+        std (Sequence[float or int], optional): The pixel standard deviation
+            of channels of images or stacked optical flow. Defaults to None.
+        pad_size_divisor (int): The size of padded image should be
+            divisible by ``pad_size_divisor``. Defaults to 1.
+        pad_value (float or int): The padded pixel value. Defaults to 0.
+        to_rgb (bool): Whether to convert image from BGR to RGB.
+            Defaults to False.
+        format_shape (str): Format shape of input data.
+            Defaults to ``'NCHW'``.
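+
+    Example:
+        An illustrative config-style usage; the statistics are placeholders
+        and ``format_shape`` is set to the video layout ``'NCTHW'``:
+
+        >>> data_preprocessor = dict(
+        ...     type='VideoDataPreprocessor',
+        ...     mean=(123.675, 116.28, 103.53),
+        ...     std=(58.395, 57.12, 57.375),
+        ...     to_rgb=True,
+        ...     format_shape='NCTHW')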
+ """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + format_shape: str = 'NCHW') -> None: + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.format_shape = format_shape + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both ' \ + '`mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + if self.format_shape == 'NCHW': + normalizer_shape = (-1, 1, 1) + elif self.format_shape == 'NCTHW': + normalizer_shape = (-1, 1, 1, 1) + else: + raise ValueError(f'Invalid format shape: {format_shape}') + + self.register_buffer( + 'mean', + torch.tensor(mean, dtype=torch.float32).view(normalizer_shape), + False) + self.register_buffer( + 'std', + torch.tensor(std, dtype=torch.float32).view(normalizer_shape), + False) + else: + self._enable_normalize = False + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[List[torch.Tensor], Optional[list]]: Data in the same format + as the model input. + """ + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + if isinstance(batch_inputs, list): + # channel transform + if self.to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = [ + _input[..., [2, 1, 0], :, :] for _input in batch_inputs + ] + elif self.format_shape == 'NCTHW': + batch_inputs = [ + _input[..., [2, 1, 0], :, :, :] + for _input in batch_inputs + ] + else: + raise ValueError( + f'Invalid format shape: {self.format_shape}') + + # convert to float after channel conversion to ensure efficiency + batch_inputs = [_input.float() for _input in batch_inputs] + + # normalization + if self._enable_normalize: + batch_inputs = [(_input - self.mean) / self.std + for _input in batch_inputs] + + else: + # channel transform + if self.to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = batch_inputs[..., [2, 1, 0], :, :] + elif self.format_shape == 'NCTHW': + batch_inputs = batch_inputs[..., [2, 1, 0], :, :, :] + else: + raise ValueError( + f'Invalid format shape: {self.format_shape}') + + # convert to float after channel conversion to ensure efficiency + batch_inputs = batch_inputs.float() + + # normalization + if self._enable_normalize: + batch_inputs = (batch_inputs - self.mean) / self.std + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class MultiModalDataPreprocessor(BaseDataPreprocessor): + """Data pre-processor for image-text multimodality tasks. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). 
+ - Normalize image with defined std and mean. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (Number): The padded pixel value. Defaults to 0. + to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + """ + + def __init__( + self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Number = 0, + to_rgb: bool = False, + ): + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both `mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + def forward(self, data: dict, training: bool = False) -> dict: + """Perform normalization, padding, bgr2rgb conversion and batch + augmentation based on ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + dict: Data in the same format as the model input. + """ + data = self.cast_data(data) + + imgs = data.get('inputs', None) + + def _process_img(img): + # ------ To RGB ------ + if self.to_rgb and img.size(1) == 3: + img = img.flip(1) + + # -- Normalization --- + img = img.float() + if self._enable_normalize: + img = (img - self.mean) / self.std + + # ------ Padding ----- + if self.pad_size_divisor > 1: + h, w = img.shape[-2:] + + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + img = F.pad(img, (0, pad_w, 0, pad_h), 'constant', + self.pad_value) + return img + + if isinstance(imgs, torch.Tensor): + imgs = _process_img(imgs) + elif isinstance(imgs, Sequence): + # B, T, C, H, W + imgs = torch.stack([_process_img(img) for img in imgs], dim=1) + elif imgs is not None: + raise ValueError(f'{type(imgs)} is not supported for imgs inputs.') + + data_samples = data.get('data_samples', None) + + return {'images': imgs, 'data_samples': data_samples} diff --git a/mmpretrain/models/utils/ema.py b/mmpretrain/models/utils/ema.py new file mode 100644 index 0000000..63c5006 --- /dev/null +++ b/mmpretrain/models/utils/ema.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from math import cos, pi +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.logging import MessageHub +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CosineEMA(ExponentialMovingAverage): + r"""CosineEMA is implemented for updating momentum parameter, used in BYOL, + MoCoV3, etc. + + All parameters are updated by the formula as below: + + .. math:: + + X'_{t+1} = (1 - m) * X'_t + m * X_t + + Where :math:`m` the the momentum parameter. And it's updated with cosine + annealing, including momentum adjustment following: + + .. 
math::
+        m = m_{end} + (m_{start} - m_{end}) * (\cos\frac{k\pi}{K} + 1) / 2
+
+    where :math:`k` is the current step, :math:`K` is the total steps.
+
+    .. note::
+        This :attr:`momentum` argument is different from one used in optimizer
+        classes and the conventional notion of momentum. Mathematically,
+        :math:`X'_{t}` is the moving average and :math:`X_t` is the new
+        observed value. The value of momentum is usually a small number,
+        allowing observed values to slowly update the ema parameters. See also
+        :external:py:class:`torch.nn.BatchNorm2d`.
+
+    Args:
+        model (nn.Module): The model to be averaged.
+        momentum (float): The start momentum value. Defaults to 0.004.
+        end_momentum (float): The end momentum value for cosine annealing.
+            Defaults to 0.
+        interval (int): Interval between two updates. Defaults to 1.
+        device (torch.device, optional): If provided, the averaged model will
+            be stored on the :attr:`device`. Defaults to None.
+        update_buffers (bool): if True, it will compute running averages for
+            both the parameters and the buffers of the model. Defaults to
+            False.
+    """
+
+    def __init__(self,
+                 model: nn.Module,
+                 momentum: float = 0.004,
+                 end_momentum: float = 0.,
+                 interval: int = 1,
+                 device: Optional[torch.device] = None,
+                 update_buffers: bool = False) -> None:
+        super().__init__(
+            model=model,
+            momentum=momentum,
+            interval=interval,
+            device=device,
+            update_buffers=update_buffers)
+        self.end_momentum = end_momentum
+
+    def avg_func(self, averaged_param: torch.Tensor,
+                 source_param: torch.Tensor, steps: int) -> None:
+        """Compute the moving average of the parameters using the cosine
+        momentum strategy.
+
+        Args:
+            averaged_param (Tensor): The averaged parameters.
+            source_param (Tensor): The source parameters.
+            steps (int): The number of times the parameters have been
+                updated.
+
+        Returns:
+            None: ``averaged_param`` is updated in place.
+        """
+        message_hub = MessageHub.get_current_instance()
+        max_iters = message_hub.get_info('max_iters')
+        cosine_annealing = (cos(pi * steps / float(max_iters)) + 1) / 2
+        momentum = self.end_momentum - (self.end_momentum -
+                                        self.momentum) * cosine_annealing
+        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
diff --git a/mmpretrain/models/utils/embed.py b/mmpretrain/models/utils/embed.py
new file mode 100644
index 0000000..8299f9a
--- /dev/null
+++ b/mmpretrain/models/utils/embed.py
@@ -0,0 +1,423 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Sequence
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.cnn.bricks.transformer import AdaptivePadding
+from mmengine.model import BaseModule
+
+from .helpers import to_2tuple
+
+
+def resize_pos_embed(pos_embed,
+                     src_shape,
+                     dst_shape,
+                     mode='bicubic',
+                     num_extra_tokens=1):
+    """Resize pos_embed weights.
+
+    Args:
+        pos_embed (torch.Tensor): Position embedding weights with shape
+            [1, L, C].
+        src_shape (tuple): The resolution of downsampled origin training
+            image, in format (H, W).
+        dst_shape (tuple): The resolution of downsampled new training
+            image, in format (H, W).
+        mode (str): Algorithm used for upsampling. Choose one from 'nearest',
+            'linear', 'bilinear', 'bicubic' and 'trilinear'.
+            Defaults to 'bicubic'.
+        num_extra_tokens (int): The number of extra tokens, such as cls_token.
+            Defaults to 1.
+ + Returns: + torch.Tensor: The resized pos_embed of shape [1, L_new, C] + """ + if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: + return pos_embed + assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]' + _, L, C = pos_embed.shape + src_h, src_w = src_shape + assert L == src_h * src_w + num_extra_tokens, \ + f"The length of `pos_embed` ({L}) doesn't match the expected " \ + f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \ + '`img_size` argument.' + extra_tokens = pos_embed[:, :num_extra_tokens] + + src_weight = pos_embed[:, num_extra_tokens:] + src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) + + # The cubic interpolate algorithm only accepts float32 + dst_weight = F.interpolate( + src_weight.float(), size=dst_shape, align_corners=False, mode=mode) + dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) + dst_weight = dst_weight.to(src_weight.dtype) + + return torch.cat((extra_tokens, dst_weight), dim=1) + + +def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head): + """Resize relative position bias table. + + Args: + src_shape (int): The resolution of downsampled origin training + image, in format (H, W). + dst_shape (int): The resolution of downsampled new training + image, in format (H, W). + table (tensor): The relative position bias of the pretrained model. + num_head (int): Number of attention heads. + + Returns: + torch.Tensor: The resized relative position bias table. + """ + from scipy import interpolate + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_shape // 2) + if gp > dst_shape // 2: + right = q + else: + left = q + + dis = [] + cur = 1 + for i in range(src_shape // 2): + dis.append(cur) + cur += q**(i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_shape // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + all_rel_pos_bias = [] + + for i in range(num_head): + z = table[:, i].view(src_shape, src_shape).float().numpy() + f_cubic = interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f_cubic(dx, + dy)).contiguous().view(-1, + 1).to(table.device)) + new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + return new_rel_pos_bias + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + img_size (int | tuple): The size of input image. Default: 224 + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None + conv_cfg (dict, optional): The config dict for conv layers. + Default: None + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None + """ + + def __init__(self, + img_size=224, + in_channels=3, + embed_dims=768, + norm_cfg=None, + conv_cfg=None, + init_cfg=None): + super(PatchEmbed, self).__init__(init_cfg) + warnings.warn('The `PatchEmbed` in mmpretrain will be deprecated. ' + 'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. 
' + "It's more general and supports dynamic input shape") + + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + self.img_size = img_size + self.embed_dims = embed_dims + + # Use conv layer to embed + conv_cfg = conv_cfg or dict() + _conv_cfg = dict( + type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1) + _conv_cfg.update(conv_cfg) + self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims) + + # Calculate how many patches a input image is splited to. + h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] - + self.projection.dilation[i] * + (self.projection.kernel_size[i] - 1) - 1) // + self.projection.stride[i] + 1 for i in range(2)] + + self.patches_resolution = (h_out, w_out) + self.num_patches = h_out * w_out + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't " \ + f'match model ({self.img_size[0]}*{self.img_size[1]}).' + # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim + x = self.projection(x).flatten(2).transpose(1, 2) + + if self.norm is not None: + x = self.norm(x) + + return x + + +# Modified from pytorch-image-models +class HybridEmbed(BaseModule): + """CNN Feature Map Embedding. + + Extract feature map from CNN, flatten, + project to embedding dim. + + Args: + backbone (nn.Module): CNN backbone + img_size (int | tuple): The size of input image. Default: 224 + feature_size (int | tuple, optional): Size of feature map extracted by + CNN backbone. Default: None + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_cfg (dict, optional): The config dict for conv layers. + Default: None. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + backbone, + img_size=224, + feature_size=None, + in_channels=3, + embed_dims=768, + conv_cfg=None, + init_cfg=None): + super(HybridEmbed, self).__init__(init_cfg) + assert isinstance(backbone, nn.Module) + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of + # determining the exact dim of the output feature + # map for all networks, the feature metadata has + # reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of + # each stage that isn't captured. 
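+                # Run the backbone once on a zero tensor in eval mode (so
+                # that BatchNorm statistics are not updated) just to read off
+                # the output feature-map size and channel dimension, then
+                # restore the original training/eval mode afterwards.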
+ training = backbone.training + if training: + backbone.eval() + o = self.backbone( + torch.zeros(1, in_channels, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + # last feature if backbone outputs list/tuple of features + o = o[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + self.num_patches = feature_size[0] * feature_size[1] + + # Use conv layer to embed + conv_cfg = conv_cfg or dict() + _conv_cfg = dict( + type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1) + _conv_cfg.update(conv_cfg) + self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + # last feature if backbone outputs list/tuple of features + x = x[-1] + x = self.projection(x).flatten(2).transpose(1, 2) + return x + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + Modified from mmcv, and this module supports specifying whether to use + post-norm. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map ((used in Swin Transformer)). Our + implementation uses :class:`torch.nn.Unfold` to merge patches, which is + about 25% faster than the original implementation. However, we need to + modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. To gets fully covered + by filter and stride you specified. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Defaults to None, which means to be set as + ``kernel_size``. + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Defaults to "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Defaults to 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults to False. + norm_cfg (dict, optional): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + use_post_norm (bool): Whether to use post normalization here. + Defaults to False. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. 
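+
+    Example:
+        A rough shape-only sketch (the sizes below are arbitrary, not taken
+        from a particular model):
+
+        >>> import torch
+        >>> merge = PatchMerging(in_channels=96, out_channels=192)
+        >>> x = torch.rand(1, 56 * 56, 96)
+        >>> out, out_size = merge(x, input_size=(56, 56))
+        >>> out.shape, out_size
+        (torch.Size([1, 784, 192]), (28, 28))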
+ """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + use_post_norm=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.use_post_norm = use_post_norm + + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adaptive_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adaptive_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + if norm_cfg is not None: + # build pre or post norm layer based on different channels + if self.use_post_norm: + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + else: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + + if self.adaptive_padding: + x = self.adaptive_padding(x) + H, W = x.shape[-2:] + + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + x = self.sampler(x) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + + if self.use_post_norm: + # use post-norm here + x = self.reduction(x) + x = self.norm(x) if self.norm else x + else: + x = self.norm(x) if self.norm else x + x = self.reduction(x) + + return x, output_size diff --git a/mmpretrain/models/utils/helpers.py b/mmpretrain/models/utils/helpers.py new file mode 100644 index 0000000..971f450 --- /dev/null +++ b/mmpretrain/models/utils/helpers.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections.abc +import warnings +from itertools import repeat + +import torch +from mmengine.utils import digit_version + + +def is_tracing() -> bool: + """Determine whether the model is called during the tracing of code with + ``torch.jit.trace``.""" + if digit_version(torch.__version__) >= digit_version('1.6.0'): + on_trace = torch.jit.is_tracing() + # In PyTorch 1.6, torch.jit.is_tracing has a bug. 
+ # Refers to https://github.com/pytorch/pytorch/issues/42448 + if isinstance(on_trace, bool): + return on_trace + else: + return torch._C._is_tracing() + else: + warnings.warn( + 'torch.jit.is_tracing is only supported after v1.6.0. ' + 'Therefore is_tracing returns False automatically. Please ' + 'set on_trace manually if you are using trace.', UserWarning) + return False + + +# From PyTorch internals +def _ntuple(n): + """A `to_tuple` function generator. + + It returns a function, this function will repeat the input to a tuple of + length ``n`` if the input is not an Iterable object, otherwise, return the + input directly. + + Args: + n (int): The number of the target length. + """ + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/mmpretrain/models/utils/huggingface.py b/mmpretrain/models/utils/huggingface.py new file mode 100644 index 0000000..a44d6da --- /dev/null +++ b/mmpretrain/models/utils/huggingface.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import contextlib +from typing import Optional + +import transformers +from mmengine.registry import Registry +from transformers import AutoConfig, PreTrainedModel +from transformers.models.auto.auto_factory import _BaseAutoModelClass + +from mmpretrain.registry import MODELS, TOKENIZER + + +def register_hf_tokenizer( + cls: Optional[type] = None, + registry: Registry = TOKENIZER, +): + """Register HuggingFace-style PreTrainedTokenizerBase class.""" + if cls is None: + + # use it as a decorator: @register_hf_tokenizer() + def _register(cls): + register_hf_tokenizer(cls=cls) + return cls + + return _register + + def from_pretrained(**kwargs): + if ('pretrained_model_name_or_path' not in kwargs + and 'name_or_path' not in kwargs): + raise TypeError( + f'{cls.__name__}.from_pretrained() missing required ' + "argument 'pretrained_model_name_or_path' or 'name_or_path'.") + # `pretrained_model_name_or_path` is too long for config, + # add an alias name `name_or_path` here. 
+ name_or_path = kwargs.pop('pretrained_model_name_or_path', + kwargs.pop('name_or_path')) + return cls.from_pretrained(name_or_path, **kwargs) + + registry._register_module(module=from_pretrained, module_name=cls.__name__) + return cls + + +_load_hf_pretrained_model = True + + +@contextlib.contextmanager +def no_load_hf_pretrained_model(): + global _load_hf_pretrained_model + _load_hf_pretrained_model = False + yield + _load_hf_pretrained_model = True + + +def register_hf_model( + cls: Optional[type] = None, + registry: Registry = MODELS, +): + """Register HuggingFace-style PreTrainedModel class.""" + if cls is None: + + # use it as a decorator: @register_hf_tokenizer() + def _register(cls): + register_hf_model(cls=cls) + return cls + + return _register + + if issubclass(cls, _BaseAutoModelClass): + get_config = AutoConfig.from_pretrained + from_config = cls.from_config + elif issubclass(cls, PreTrainedModel): + get_config = cls.config_class.from_pretrained + from_config = cls + else: + raise TypeError('Not auto model nor pretrained model of huggingface.') + + def build(**kwargs): + if ('pretrained_model_name_or_path' not in kwargs + and 'name_or_path' not in kwargs): + raise TypeError( + f'{cls.__name__} missing required argument ' + '`pretrained_model_name_or_path` or `name_or_path`.') + # `pretrained_model_name_or_path` is too long for config, + # add an alias name `name_or_path` here. + name_or_path = kwargs.pop('pretrained_model_name_or_path', + kwargs.pop('name_or_path')) + + if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model: + model = cls.from_pretrained(name_or_path, **kwargs) + setattr(model, 'is_init', True) + return model + else: + cfg = get_config(name_or_path, **kwargs) + return from_config(cfg) + + registry._register_module(module=build, module_name=cls.__name__) + return cls + + +register_hf_model(transformers.AutoModelForCausalLM) diff --git a/mmpretrain/models/utils/inverted_residual.py b/mmpretrain/models/utils/inverted_residual.py new file mode 100644 index 0000000..8387b21 --- /dev/null +++ b/mmpretrain/models/utils/inverted_residual.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule + +from .se_layer import SELayer + + +class InvertedResidual(BaseModule): + """Inverted Residual Block. + + Args: + in_channels (int): The input channels of this module. + out_channels (int): The output channels of this module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Defaults to 3. + stride (int): The stride of the depthwise convolution. Defaults to 1. + se_cfg (dict, optional): Config dict for se layer. Defaults to None, + which means no se layer. + conv_cfg (dict): Config dict for convolution layer. Defaults to None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='BN')``. + act_cfg (dict): Config dict for activation layer. + Defaults to ``dict(type='ReLU')``. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict | list[dict], optional): Initialization config dict. 
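+
+    Example:
+        A rough usage sketch (the channel numbers are arbitrary):
+
+        >>> import torch
+        >>> block = InvertedResidual(
+        ...     in_channels=32, out_channels=32, mid_channels=96)
+        >>> x = torch.rand(1, 32, 56, 56)
+        >>> block(x).shape
+        torch.Size([1, 32, 56, 56])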
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + drop_path_rate=0., + with_cp=False, + init_cfg=None): + super(InvertedResidual, self).__init__(init_cfg) + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.with_se = se_cfg is not None + self.with_expand_conv = (mid_channels != in_channels) + + if self.with_se: + assert isinstance(se_cfg, dict) + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if self.with_se: + self.se = SELayer(**se_cfg) + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + """Forward function. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + self.drop_path(out) + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/mmpretrain/models/utils/layer_scale.py b/mmpretrain/models/utils/layer_scale.py new file mode 100644 index 0000000..bb480a1 --- /dev/null +++ b/mmpretrain/models/utils/layer_scale.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +import torch.nn as nn + + +class LayerScale(nn.Module): + """LayerScale layer. + + Args: + dim (int): Dimension of input features. + layer_scale_init_value (float or torch.Tensor): Init value of layer + scale. Defaults to 1e-5. + inplace (bool): inplace: can optionally do the + operation in-place. Defaults to False. + data_format (str): The input data format, could be 'channels_last' + or 'channels_first', representing (B, C, H, W) and + (B, N, C) format data respectively. Defaults to 'channels_last'. + """ + + def __init__(self, + dim: int, + layer_scale_init_value: Union[float, torch.Tensor] = 1e-5, + inplace: bool = False, + data_format: str = 'channels_last'): + super().__init__() + assert data_format in ('channels_last', 'channels_first'), \ + "'data_format' could only be channels_last or channels_first." 
+ self.inplace = inplace + self.data_format = data_format + self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value) + + def forward(self, x): + if self.data_format == 'channels_first': + if self.inplace: + return x.mul_(self.weight.view(-1, 1, 1)) + else: + return x * self.weight.view(-1, 1, 1) + return x.mul_(self.weight) if self.inplace else x * self.weight diff --git a/mmpretrain/models/utils/make_divisible.py b/mmpretrain/models/utils/make_divisible.py new file mode 100644 index 0000000..1ec7468 --- /dev/null +++ b/mmpretrain/models/utils/make_divisible.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number down to the nearest value that can + be divisible by the divisor. + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int, optional): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel + number to the original channel number. Default: 0.9. + Returns: + int: The modified output channel number + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/mmpretrain/models/utils/norm.py b/mmpretrain/models/utils/norm.py new file mode 100644 index 0000000..8b890a0 --- /dev/null +++ b/mmpretrain/models/utils/norm.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class GRN(nn.Module): + """Global Response Normalization Module. + + Come from `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked + Autoencoders `_ + + Args: + in_channels (int): The number of channels of the input tensor. + eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-6. + """ + + def __init__(self, in_channels, eps=1e-6): + super().__init__() + self.in_channels = in_channels + self.gamma = nn.Parameter(torch.zeros(in_channels)) + self.beta = nn.Parameter(torch.zeros(in_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor, data_format='channel_first'): + """Forward method. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". + """ + if data_format == 'channel_last': + gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps) + x = self.gamma * (x * nx) + self.beta + x + elif data_format == 'channel_first': + gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps) + x = self.gamma.view(1, -1, 1, 1) * (x * nx) + self.beta.view( + 1, -1, 1, 1) + x + return x + + +@MODELS.register_module('LN2d') +class LayerNorm2d(nn.LayerNorm): + """LayerNorm on channels for 2d images. + + Args: + num_channels (int): The number of channels of the input tensor. 
+ eps (float): a value added to the denominator for numerical stability. + Defaults to 1e-5. + elementwise_affine (bool): a boolean value that when set to ``True``, + this module has learnable per-element affine parameters initialized + to ones (for weights) and zeros (for biases). Defaults to True. + """ + + def __init__(self, num_channels: int, **kwargs) -> None: + super().__init__(num_channels, **kwargs) + self.num_channels = self.normalized_shape[0] + + def forward(self, x, data_format='channel_first'): + """Forward method. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". + """ + assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \ + f'(N, C, H, W), but got tensor with shape {x.shape}' + if data_format == 'channel_last': + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, + self.eps) + elif data_format == 'channel_first': + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, + self.eps) + # If the output is discontiguous, it may cause some unexpected + # problem in the downstream tasks + x = x.permute(0, 3, 1, 2).contiguous() + return x + + +def build_norm_layer(cfg: dict, num_features: int) -> nn.Module: + """Build normalization layer. + + Args: + cfg (dict): The norm layer config, which should contain: + + - type (str): Layer type. + - layer args: Args needed to instantiate a norm layer. + + num_features (int): Number of input channels. + + Returns: + nn.Module: The created norm layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + norm_layer = MODELS.get(layer_type) + if norm_layer is None: + raise KeyError(f'Cannot find {layer_type} in registry under scope ' + f'name {MODELS.scope}') + + requires_grad = cfg_.pop('requires_grad', True) + cfg_.setdefault('eps', 1e-5) + + if layer_type != 'GN': + layer = norm_layer(num_features, **cfg_) + else: + layer = norm_layer(num_channels=num_features, **cfg_) + + if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): + layer._specify_ddp_gpu_num(1) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return layer diff --git a/mmpretrain/models/utils/position_encoding.py b/mmpretrain/models/utils/position_encoding.py new file mode 100644 index 0000000..07a3c48 --- /dev/null +++ b/mmpretrain/models/utils/position_encoding.py @@ -0,0 +1,247 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import partial +from typing import Optional, Sequence, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from mmengine.utils import digit_version + +from ..utils import to_2tuple + +# After pytorch v1.10.0, use torch.meshgrid without indexing +# will raise extra warning. For more details, +# refers to https://github.com/pytorch/pytorch/issues/50276 +if digit_version(torch.__version__) >= digit_version('1.10.0'): + torch_meshgrid = partial(torch.meshgrid, indexing='ij') +else: + torch_meshgrid = torch.meshgrid + + +class ConditionalPositionEncoding(BaseModule): + """The Conditional Position Encoding (CPE) module. 
+ + The CPE is the implementation of 'Conditional Positional Encodings + for Vision Transformers '_. + + Args: + in_channels (int): Number of input channels. + embed_dims (int): The feature dimension. Default: 768. + stride (int): Stride of conv layer. Default: 1. + """ + + def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): + super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + groups=embed_dims) + self.stride = stride + + def forward(self, x, hw_shape): + B, N, C = x.shape + H, W = hw_shape + feat_token = x + # convert (B, N, C) to (B, C, H, W) + cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W).contiguous() + if self.stride == 1: + x = self.proj(cnn_feat) + cnn_feat + else: + x = self.proj(cnn_feat) + x = x.flatten(2).transpose(1, 2) + return x + + +class PositionEncodingFourier(BaseModule): + """The Position Encoding Fourier (PEF) module. + + The PEF is adopted from EdgeNeXt '_. + Args: + in_channels (int): Number of input channels. + Default: 32 + embed_dims (int): The feature dimension. + Default: 768. + temperature (int): Temperature. + Default: 10000. + dtype (torch.dtype): The data type. + Default: torch.float32. + init_cfg (dict): The config dict for initializing the module. + Default: None. + """ + + def __init__(self, + in_channels=32, + embed_dims=768, + temperature=10000, + dtype=torch.float32, + init_cfg=None): + super(PositionEncodingFourier, self).__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d(in_channels * 2, embed_dims, kernel_size=1) + self.scale = 2 * math.pi + self.in_channels = in_channels + self.embed_dims = embed_dims + self.dtype = dtype + + if digit_version(torch.__version__) < digit_version('1.8.0'): + floor_div = torch.floor_divide + else: + floor_div = partial(torch.div, rounding_mode='floor') + dim_t = torch.arange(in_channels, dtype=self.dtype) + self.dim_t = temperature**(2 * floor_div(dim_t, 2) / in_channels) + + def forward(self, bhw_shape): + B, H, W = bhw_shape + mask = torch.zeros(B, H, W).bool().to(self.proj.weight.device) + not_mask = ~mask + eps = 1e-6 + y_embed = not_mask.cumsum(1, dtype=self.dtype) + x_embed = not_mask.cumsum(2, dtype=self.dtype) + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = self.dim_t.to(mask.device) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.proj(pos) + + return pos + + +def build_2d_sincos_position_embedding( + patches_resolution: Union[int, Sequence[int]], + embed_dims: int, + temperature: Optional[int] = 10000., + cls_token: Optional[bool] = False) -> torch.Tensor: + """The function is to build position embedding for model to obtain the + position information of the image patches. + + Args: + patches_resolution (Union[int, Sequence[int]]): The resolution of each + patch. + embed_dims (int): The dimension of the embedding vector. + temperature (int, optional): The temperature parameter. Defaults to + 10000. + cls_token (bool, optional): Whether to concatenate class token. + Defaults to False. + + Returns: + torch.Tensor: The position embedding vector. 
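+
+    Example:
+        An illustrative call (the patch resolution is arbitrary):
+
+        >>> pos_embed = build_2d_sincos_position_embedding(
+        ...     patches_resolution=(14, 14), embed_dims=768)
+        >>> pos_embed.shape
+        torch.Size([1, 196, 768])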
+ """ + + if isinstance(patches_resolution, int): + patches_resolution = (patches_resolution, patches_resolution) + + h, w = patches_resolution + grid_w = torch.arange(w, dtype=torch.float32) + grid_h = torch.arange(h, dtype=torch.float32) + grid_w, grid_h = torch_meshgrid(grid_w, grid_h) + assert embed_dims % 4 == 0, \ + 'Embed dimension must be divisible by 4.' + pos_dim = embed_dims // 4 + + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1. / (temperature**omega) + out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) + out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) + + pos_emb = torch.cat( + [ + torch.sin(out_w), + torch.cos(out_w), + torch.sin(out_h), + torch.cos(out_h) + ], + dim=1, + )[None, :, :] + + if cls_token: + cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32) + pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1) + + return pos_emb + + +class RotaryEmbeddingFast(BaseModule): + """Implements 2D rotary embedding (RoPE) for image tokens. Position + encoding is implemented with sin and cos functions, + + .. math:: + Pos_{cos} = cos(\frac{t}{\theta^{\frac{2i}{d}}} \\ + Pos_{sin} = sin(\frac{t}{\theta^{\frac{2i}{d}}} + Args: + embed_dims (int): The feature dimension for each head. + patch_resolution (int | tuple): The resolution of the + image, in format (H, W). + theta (float): The hyperparameter for position coding. + Defaults to 10000. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + patch_resolution, + theta=10000., + init_cfg=None): + super(RotaryEmbeddingFast, self).__init__(init_cfg=init_cfg) + + self.half_dim = embed_dims // 2 + self.patch_resolution = to_2tuple(patch_resolution) + self.theta = theta + + freqs_cos, freqs_sin = self.compute_position_embedding() + self.register_buffer('freqs_cos', freqs_cos) + self.register_buffer('freqs_sin', freqs_sin) + + def compute_position_embedding(self): + frequency = self.theta**( + torch.arange(0, self.half_dim, 2).float() / self.half_dim) + frequency = 1. 
/ frequency + + h, w = self.patch_resolution + th = torch.arange(h) / h * self.half_dim + tw = torch.arange(w) / w * self.half_dim + + position_h = (th[:, None] @ frequency[None, :]).repeat(1, 2) + position_w = (tw[:, None] @ frequency[None, :]).repeat(1, 2) + + height = position_h[:, None, :].expand(h, w, self.half_dim) + width = position_w[None, :, :].expand(h, w, self.half_dim) + position = torch.cat((height, width), dim=-1) + + freqs_cos = position.cos().view(-1, position.shape[-1]) + freqs_sin = position.sin().view(-1, position.shape[-1]) + + return freqs_cos, freqs_sin + + def forward(self, x, patch_resolution): + # Check whether the patch resolution is the predefined size + patch_resolution = to_2tuple(patch_resolution) + if patch_resolution != self.patch_resolution: + self.patch_resolution = patch_resolution + freqs_cos, freqs_sin = self.compute_position_embedding() + self.register_buffer('freqs_cos', freqs_cos.to(x.device)) + self.register_buffer('freqs_sin', freqs_sin.to(x.device)) + + batch, num_heads, num_patches, dim = x.shape + + inputs = x + x = x.reshape(batch, num_heads, num_patches, -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + x = x.reshape(batch, num_heads, num_patches, dim) + + return inputs * self.freqs_cos + x * self.freqs_sin diff --git a/mmpretrain/models/utils/res_layer_extra_norm.py b/mmpretrain/models/utils/res_layer_extra_norm.py new file mode 100644 index 0000000..37e387b --- /dev/null +++ b/mmpretrain/models/utils/res_layer_extra_norm.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .norm import build_norm_layer + +try: + from mmdet.models.backbones import ResNet + from mmdet.models.roi_heads.shared_heads.res_layer import ResLayer + from mmdet.registry import MODELS + + @MODELS.register_module() + class ResLayerExtraNorm(ResLayer): + """Add extra norm to original ``ResLayer``.""" + + def __init__(self, *args, **kwargs): + super(ResLayerExtraNorm, self).__init__(*args, **kwargs) + + block = ResNet.arch_settings[kwargs['depth']][0] + self.add_module( + 'norm', + build_norm_layer(self.norm_cfg, + 64 * 2**self.stage * block.expansion)) + + def forward(self, x): + """Forward function.""" + res_layer = getattr(self, f'layer{self.stage + 1}') + norm = getattr(self, 'norm') + x = res_layer(x) + out = norm(x) + return out + +except ImportError: + ResLayerExtraNorm = None diff --git a/mmpretrain/models/utils/se_layer.py b/mmpretrain/models/utils/se_layer.py new file mode 100644 index 0000000..2029017 --- /dev/null +++ b/mmpretrain/models/utils/se_layer.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.utils import is_tuple_of + +from .make_divisible import make_divisible + + +class SELayer(BaseModule): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + squeeze_channels (None or int): The intermediate channel number of + SElayer. Default: None, means the value of ``squeeze_channels`` + is ``make_divisible(channels // ratio, divisor)``. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will + be ``make_divisible(channels // ratio, divisor)``. Only used when + ``squeeze_channels`` is None. Default: 16. + divisor(int): The divisor to true divide the channel number. Only + used when ``squeeze_channels`` is None. Default: 8. + conv_cfg (None or dict): Config dict for convolution layer. 
Default: + None, which means using conv2d. + return_weight(bool): Whether to return the weight. Default: False. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configurated + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configurated by the first dict and the + second activation layer will be configurated by the second dict. + Default: (dict(type='ReLU'), dict(type='Sigmoid')) + """ + + def __init__(self, + channels, + squeeze_channels=None, + ratio=16, + divisor=8, + bias='auto', + conv_cfg=None, + act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')), + return_weight=False, + init_cfg=None): + super(SELayer, self).__init__(init_cfg) + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + if squeeze_channels is None: + squeeze_channels = make_divisible(channels // ratio, divisor) + assert isinstance(squeeze_channels, int) and squeeze_channels > 0, \ + '"squeeze_channels" should be a positive integer, but get ' + \ + f'{squeeze_channels} instead.' + self.return_weight = return_weight + self.conv1 = ConvModule( + in_channels=channels, + out_channels=squeeze_channels, + kernel_size=1, + stride=1, + bias=bias, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=squeeze_channels, + out_channels=channels, + kernel_size=1, + stride=1, + bias=bias, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + if self.return_weight: + return out + else: + return x * out diff --git a/mmpretrain/models/utils/sparse_modules.py b/mmpretrain/models/utils/sparse_modules.py new file mode 100644 index 0000000..dd6bf34 --- /dev/null +++ b/mmpretrain/models/utils/sparse_modules.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved. 
+# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS + + +class SparseHelper: + """The helper to compute sparse operation with pytorch, such as sparse + convlolution, sparse batch norm, etc.""" + + _cur_active: torch.Tensor = None + + @staticmethod + def _get_active_map_or_index(H: int, + returning_active_map: bool = True + ) -> torch.Tensor: + """Get current active map with (B, 1, f, f) shape or index format.""" + # _cur_active with shape (B, 1, f, f) + downsample_raito = H // SparseHelper._cur_active.shape[-1] + active_ex = SparseHelper._cur_active.repeat_interleave( + downsample_raito, 2).repeat_interleave(downsample_raito, 3) + return active_ex if returning_active_map else active_ex.squeeze( + 1).nonzero(as_tuple=True) + + @staticmethod + def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor: + """Sparse convolution forward function.""" + x = super(type(self), self).forward(x) + + # (b, c, h, w) *= (b, 1, h, w), mask the output of conv + x *= SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=True) + return x + + @staticmethod + def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor: + """Sparse batch norm forward function.""" + active_index = SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=False) + + # (b, c, h, w) -> (b, h, w, c) + x_permuted = x.permute(0, 2, 3, 1) + + # select the features on non-masked positions to form flatten features + # with shape (n, c) + x_flattened = x_permuted[active_index] + + # use BN1d to normalize this flatten feature (n, c) + x_flattened = super(type(self), self).forward(x_flattened) + + # generate output + output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype) + output[active_index] = x_flattened + + # (b, h, w, c) -> (b, c, h, w) + output = output.permute(0, 3, 1, 2) + return output + + +class SparseConv2d(nn.Conv2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +class SparseMaxPooling(nn.MaxPool2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +class SparseAvgPooling(nn.AvgPool2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +@MODELS.register_module() +class SparseBatchNorm2d(nn.BatchNorm1d): + """hack: override the forward function. + See `sp_bn_forward` above for more details + """ + forward = SparseHelper.sp_bn_forward + + +@MODELS.register_module() +class SparseSyncBatchNorm2d(nn.SyncBatchNorm): + """hack: override the forward function. + See `sp_bn_forward` above for more details + """ + forward = SparseHelper.sp_bn_forward + + +@MODELS.register_module('SparseLN2d') +class SparseLayerNorm2D(nn.LayerNorm): + """Implementation of sparse LayerNorm on channels for 2d images.""" + + def forward(self, + x: torch.Tensor, + data_format='channel_first') -> torch.Tensor: + """Sparse layer norm forward function with 2D data. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". 
+ """ + assert x.dim() == 4, ( + f'LayerNorm2d only supports inputs with shape ' + f'(N, C, H, W), but got tensor with shape {x.shape}') + if data_format == 'channel_last': + index = SparseHelper._get_active_map_or_index( + H=x.shape[1], returning_active_map=False) + + # select the features on non-masked positions to form flatten + # features with shape (n, c) + x_flattened = x[index] + # use LayerNorm to normalize this flatten feature (n, c) + x_flattened = super().forward(x_flattened) + + # generate output + x = torch.zeros_like(x, dtype=x_flattened.dtype) + x[index] = x_flattened + elif data_format == 'channel_first': + index = SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=False) + x_permuted = x.permute(0, 2, 3, 1) + + # select the features on non-masked positions to form flatten + # features with shape (n, c) + x_flattened = x_permuted[index] + # use LayerNorm to normalize this flatten feature (n, c) + x_flattened = super().forward(x_flattened) + + # generate output + x = torch.zeros_like(x_permuted, dtype=x_flattened.dtype) + x[index] = x_flattened + x = x.permute(0, 3, 1, 2).contiguous() + else: + raise NotImplementedError + return x diff --git a/mmpretrain/models/utils/swiglu_ffn.py b/mmpretrain/models/utils/swiglu_ffn.py new file mode 100644 index 0000000..20b4591 --- /dev/null +++ b/mmpretrain/models/utils/swiglu_ffn.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.drop import build_dropout + +from .layer_scale import LayerScale +from .norm import build_norm_layer + + +class SwiGLUFFN(nn.Module): + """SwiGLU FFN layer. + + Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0., + bias: bool = True, + dropout_layer: Optional[dict] = None, + norm_cfg: Optional[dict] = None, + add_identity: bool = True, + ) -> None: + super().__init__() + self.embed_dims = embed_dims + self.out_dims = out_dims or embed_dims + hidden_dims = feedforward_channels or embed_dims + + self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, hidden_dims) + else: + self.norm = nn.Identity() + + self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias) + + if layer_scale_init_value > 0: + self.gamma2 = LayerScale( + dim=embed_dims, layer_scale_init_value=layer_scale_init_value) + else: + self.gamma2 = nn.Identity() + + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + def forward(self, + x: torch.Tensor, + identity: Optional[torch.Tensor] = None) -> torch.Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + hidden = self.norm(hidden) + out = self.w3(hidden) + out = self.gamma2(out) + out = self.dropout_layer(out) + + if self.out_dims != self.embed_dims or not self.add_identity: + # due to the dimension inconsistence or user setting + # not to apply residual operation + return out + + if identity is None: + identity = x + return identity + out + + +class SwiGLUFFNFused(SwiGLUFFN): + """SwiGLU FFN layer with fusing. 
+ + Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0., + bias: bool = True, + ) -> None: + out_dims = out_dims or embed_dims + feedforward_channels = feedforward_channels or embed_dims + feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8 + super().__init__( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + out_dims=out_dims, + layer_scale_init_value=layer_scale_init_value, + bias=bias, + ) diff --git a/mmpretrain/models/utils/tokenizer.py b/mmpretrain/models/utils/tokenizer.py new file mode 100644 index 0000000..fddda43 --- /dev/null +++ b/mmpretrain/models/utils/tokenizer.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections +import os + +from mmengine.fileio import list_from_file +from transformers import (AutoTokenizer, BartTokenizer, BasicTokenizer, + BertTokenizer, BertTokenizerFast, LlamaTokenizer, + WordpieceTokenizer) + +from mmpretrain.registry import TOKENIZER +from .huggingface import register_hf_tokenizer + +register_hf_tokenizer(AutoTokenizer) +register_hf_tokenizer(LlamaTokenizer) +register_hf_tokenizer(BertTokenizer) + + +@register_hf_tokenizer() +class BlipTokenizer(BertTokenizerFast): + """"BlipTokenizer inherit BertTokenizerFast (fast, Rust-based).""" + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + os.environ['TOKENIZERS_PARALLELISM'] = 'true' + + tokenizer = super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) + return tokenizer + + +@register_hf_tokenizer() +class Blip2Tokenizer(BertTokenizer): + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + tokenizer = super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + return tokenizer + + +@register_hf_tokenizer() +class OFATokenizer(BartTokenizer): + + vocab_files_names = { + 'vocab_file': 'vocab.json', + 'merges_file': 'merges.txt' + } + + pretrained_vocab_files_map = { + 'vocab_file': { + 'OFA-Sys/OFA-tiny': + 'https://huggingface.co/OFA-Sys/OFA-tiny/blob/main/vocab.json', + 'OFA-Sys/OFA-medium': + 'https://huggingface.co/OFA-Sys/OFA-medium/blob/main/vocab.json', + 'OFA-Sys/OFA-base': + 'https://huggingface.co/OFA-Sys/OFA-base/blob/main/vocab.json', + 'OFA-Sys/OFA-large': + 'https://huggingface.co/OFA-Sys/OFA-large/blob/main/vocab.json', + }, + 'merges_file': { + 'OFA-Sys/OFA-tiny': + 'https://huggingface.co/OFA-Sys/OFA-tiny/blob/main/merges.txt', + 'OFA-Sys/OFA-medium': + 'https://huggingface.co/OFA-Sys/OFA-medium/blob/main/merges.txt', + 'OFA-Sys/OFA-base': + 'https://huggingface.co/OFA-Sys/OFA-base/blob/main/merges.txt', + 'OFA-Sys/OFA-large': + 'https://huggingface.co/OFA-Sys/OFA-large/blob/main/merges.txt', + }, + } + + max_model_input_sizes = { + 'OFA-Sys/OFA-tiny': 1024, + 'OFA-Sys/OFA-medium': 1024, + 'OFA-Sys/OFA-base': 1024, + 'OFA-Sys/OFA-large': 1024, + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ): + num_bins = kwargs.pop('num_bins', 1000) + tokenizer = 
super().from_pretrained( + pretrained_model_name_or_path, + *init_inputs, + **kwargs, + ) + length = len(tokenizer) + tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) + tokenizer.code_offset = length + tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(num_bins)]) + tokenizer.bin_offset = length + 8192 + tokenizer.num_bins = num_bins + return tokenizer + + +@TOKENIZER.register_module() +class FullTokenizer(BertTokenizer): + """Runs end-to-end tokenization.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = self.load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token='[UNK]', max_input_chars_per_word=200) + + def load_vocab(self, vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + vocab_list = list_from_file(vocab_file) + for token in vocab_list: + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_by_vocab(self, vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + def convert_tokens_to_ids(self, tokens): + return self.convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return self.convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """Converts a sequence of tokens (string) into a single string.""" + + def clean_up_tokenization(out_string): + """Clean up a list of simple English tokenization artifacts like + spaces before punctuation and abbreviated forms.""" + out_string = ( + out_string.replace(' .', '.').replace(' ?', '?').replace( + ' !', '!').replace(' ,', ',').replace(" ' ", "'").replace( + " n't", "n't").replace(" 'm", "'m").replace( + " 's", "'s").replace(" 've", + "'ve").replace(" 're", "'re")) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) diff --git a/mmpretrain/models/utils/vector_quantizer.py b/mmpretrain/models/utils/vector_quantizer.py new file mode 100644 index 0000000..7c2ea89 --- /dev/null +++ b/mmpretrain/models/utils/vector_quantizer.py @@ -0,0 +1,232 @@ +# Copyright (c) OpenMMLab. All rights reserved.
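A small usage sketch for the FullTokenizer above; the vocabulary path is a placeholder and the file is expected to hold one token per line (including '[UNK]'):

from mmpretrain.models.utils.tokenizer import FullTokenizer

tokenizer = FullTokenizer(vocab_file='path/to/vocab.txt', do_lower_case=True)
tokens = tokenizer.tokenize('A photo of a cat.')   # basic split + WordPiece
ids = tokenizer.convert_tokens_to_ids(tokens)
text = FullTokenizer.convert_tokens_to_string(tokens)  # undo '##' and spacing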
+# Copyright (c) 2022 Microsoft +# Modified from +# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from mmengine.dist import all_reduce + + +def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average with norm data.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1)) + + +def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor: + """Sample vectors according to the given number.""" + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num, ), device=device) + + return samples[indices] + + +def kmeans(samples: torch.Tensor, + num_clusters: int, + num_iters: int = 10, + use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """Run k-means algorithm.""" + dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = F.normalize(new_means, p=2, dim=-1) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + """The codebook of embedding vectors. + + Args: + num_tokens (int): Number of embedding vectors in the codebook. + codebook_dim (int) : The dimension of embedding vectors in the + codebook. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. 
+ """ + + def __init__(self, + num_tokens: int, + codebook_dim: int, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + if codebook_init_path is None: + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = F.normalize(weight, p=2, dim=-1) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f'load init codebook weight from {codebook_init_path}') + codebook_ckpt_weight = torch.load( + codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data: torch.Tensor) -> None: + """Initialize embedding vectors of codebook.""" + if self.initted: + return + print('Performing K-means init for codebook') + embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id: torch.Tensor) -> torch.Tensor: + """Get embedding vectors.""" + return F.embedding(embed_id, self.weight) + + +class NormEMAVectorQuantizer(nn.Module): + """Normed EMA vector quantizer module. + + Args: + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + beta (float): The mutiplier for VectorQuantizer embedding loss. + Defaults to 1. + decay (float): The decay parameter of EMA. Defaults to 0.99. + statistic_code_usage (bool): Whether to use cluster_size to record + statistic. Defaults to True. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. 
+ """ + + def __init__(self, + num_embed: int, + embed_dims: int, + beta: float, + decay: float = 0.99, + statistic_code_usage: bool = True, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None) -> None: + super().__init__() + self.codebook_dim = embed_dims + self.num_tokens = num_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA( + num_tokens=self.num_tokens, + codebook_dim=self.codebook_dim, + kmeans_init=kmeans_init, + codebook_init_path=codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(num_embed)) + + def reset_cluster_size(self, device): + + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + """Forward function.""" + # reshape z -> (batch, height, width, channel) + z = rearrange(z, 'b c h w -> b h w c') + z = F.normalize(z, p=2, dim=-1) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + # 'n d -> d n' + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + all_reduce(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # update cluster size with EMA + bins = encodings.sum(0) + all_reduce(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) + + embed_sum = z_flattened.t() @ encodings + all_reduce(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = F.normalize(embed_normalized, p=2, dim=-1) + embed_normalized = torch.where(zero_mask[..., None], + self.embedding.weight, + embed_normalized) + + # Update embedding vectors with EMA + norm_ema_inplace(self.embedding.weight, embed_normalized, + self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, encoding_indices diff --git a/mmpretrain/registry.py b/mmpretrain/registry.py new file mode 100644 index 0000000..cac2bda --- /dev/null +++ b/mmpretrain/registry.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMPretrain provides 21 registry nodes to support using modules across +projects. Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. 
+""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +__all__ = [ + 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', 'HOOKS', 'LOG_PROCESSORS', + 'OPTIMIZERS', 'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS', + 'PARAM_SCHEDULERS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', + 'MODEL_WRAPPERS', 'WEIGHT_INITIALIZERS', 'BATCH_AUGMENTS', 'TASK_UTILS', + 'METRICS', 'EVALUATORS', 'VISUALIZERS', 'VISBACKENDS' +] + +####################################################################### +# mmpretrain.engine # +####################################################################### + +# Runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry( + 'runner', + parent=MMENGINE_RUNNERS, + locations=['mmpretrain.engine'], +) +# Runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', + parent=MMENGINE_RUNNER_CONSTRUCTORS, + locations=['mmpretrain.engine'], +) +# Loops which define the training or test process, like `EpochBasedTrainLoop` +LOOPS = Registry( + 'loop', + parent=MMENGINE_LOOPS, + locations=['mmpretrain.engine'], +) +# Hooks to add additional functions during running, like `CheckpointHook` +HOOKS = Registry( + 'hook', + parent=MMENGINE_HOOKS, + locations=['mmpretrain.engine'], +) +# Log processors to process the scalar log data. +LOG_PROCESSORS = Registry( + 'log processor', + parent=MMENGINE_LOG_PROCESSORS, + locations=['mmpretrain.engine'], +) +# Optimizers to optimize the model weights, like `SGD` and `Adam`. +OPTIMIZERS = Registry( + 'optimizer', + parent=MMENGINE_OPTIMIZERS, + locations=['mmpretrain.engine'], +) +# Optimizer wrappers to enhance the optimization process. +OPTIM_WRAPPERS = Registry( + 'optimizer_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmpretrain.engine'], +) +# Optimizer constructors to customize the hyperparameters of optimizers. +OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer wrapper constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmpretrain.engine'], +) +# Parameter schedulers to dynamically adjust optimization parameters. 
+PARAM_SCHEDULERS = Registry( + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmpretrain.engine'], +) + +####################################################################### +# mmpretrain.datasets # +####################################################################### + +# Datasets like `ImageNet` and `CIFAR10`. +DATASETS = Registry( + 'dataset', + parent=MMENGINE_DATASETS, + locations=['mmpretrain.datasets'], +) +# Samplers to sample the dataset. +DATA_SAMPLERS = Registry( + 'data sampler', + parent=MMENGINE_DATA_SAMPLERS, + locations=['mmpretrain.datasets'], +) +# Transforms to process the samples from the dataset. +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmpretrain.datasets'], +) + +####################################################################### +# mmpretrain.models # +####################################################################### + +# Neural network modules inheriting `nn.Module`. +MODELS = Registry( + 'model', + parent=MMENGINE_MODELS, + locations=['mmpretrain.models'], +) +# Model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmpretrain.models'], +) +# Weight initialization methods like uniform, xavier. +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmpretrain.models'], +) +# Batch augmentations like `Mixup` and `CutMix`. +BATCH_AUGMENTS = Registry( + 'batch augment', + locations=['mmpretrain.models'], +) +# Task-specific modules like anchor generators and box coders +TASK_UTILS = Registry( + 'task util', + parent=MMENGINE_TASK_UTILS, + locations=['mmpretrain.models'], +) +# Tokenizer to encode sequence +TOKENIZER = Registry( + 'tokenizer', + locations=['mmpretrain.models'], +) + +####################################################################### +# mmpretrain.evaluation # +####################################################################### + +# Metrics to evaluate the model prediction results. +METRICS = Registry( + 'metric', + parent=MMENGINE_METRICS, + locations=['mmpretrain.evaluation'], +) +# Evaluators to define the evaluation process. +EVALUATORS = Registry( + 'evaluator', + parent=MMENGINE_EVALUATOR, + locations=['mmpretrain.evaluation'], +) + +####################################################################### +# mmpretrain.visualization # +####################################################################### + +# Visualizers to display task-specific results. +VISUALIZERS = Registry( + 'visualizer', + parent=MMENGINE_VISUALIZERS, + locations=['mmpretrain.visualization'], +) +# Backends to save the visualization results, like TensorBoard, WandB. +VISBACKENDS = Registry( + 'vis_backend', + parent=MMENGINE_VISBACKENDS, + locations=['mmpretrain.visualization'], +) diff --git a/mmpretrain/structures/__init__.py b/mmpretrain/structures/__init__.py new file mode 100644 index 0000000..e7de863 --- /dev/null +++ b/mmpretrain/structures/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
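The registries above all inherit from MMEngine's root registries, so any module registered here can be referenced by name from a config dict. A hedged sketch with an illustrative module (TinyHead is not part of MMPretrain):

import torch.nn as nn
from mmpretrain.registry import MODELS

@MODELS.register_module()
class TinyHead(nn.Module):
    """Toy classification head used only to illustrate the registry."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        return self.fc(x)

# Configs refer to the class by its registered name.
head = MODELS.build(dict(type='TinyHead', in_channels=512, num_classes=10))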
+from .data_sample import DataSample +from .multi_task_data_sample import MultiTaskDataSample +from .utils import (batch_label_to_onehot, cat_batch_labels, format_label, + format_score, label_to_onehot, tensor_split) + +__all__ = [ + 'DataSample', 'batch_label_to_onehot', 'cat_batch_labels', 'tensor_split', + 'MultiTaskDataSample', 'label_to_onehot', 'format_label', 'format_score' +] diff --git a/mmpretrain/structures/data_sample.py b/mmpretrain/structures/data_sample.py new file mode 100644 index 0000000..ce588b8 --- /dev/null +++ b/mmpretrain/structures/data_sample.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing.reduction import ForkingPickler +from typing import Union + +import numpy as np +import torch +from mmengine.structures import BaseDataElement + +from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score + + +class DataSample(BaseDataElement): + """A general data structure interface. + + It's used as the interface between different components. + + The following fields are convention names in MMPretrain, and we will set or + get these fields in data transforms, models, and metrics if needed. You can + also set any new fields for your need. + + Meta fields: + img_shape (Tuple): The shape of the corresponding input image. + ori_shape (Tuple): The original shape of the corresponding image. + sample_idx (int): The index of the sample in the dataset. + num_classes (int): The number of all categories. + + Data fields: + gt_label (tensor): The ground truth label. + gt_score (tensor): The ground truth score. + pred_label (tensor): The predicted label. + pred_score (tensor): The predicted score. + mask (tensor): The mask used in masked image modeling. + + Examples: + >>> import torch + >>> from mmpretrain.structures import DataSample + >>> + >>> img_meta = dict(img_shape=(960, 720), num_classes=5) + >>> data_sample = DataSample(metainfo=img_meta) + >>> data_sample.set_gt_label(3) + >>> print(data_sample) + + >>> + >>> # For multi-label data + >>> data_sample = DataSample().set_gt_label([0, 1, 4]) + >>> print(data_sample) + + >>> + >>> # Set one-hot format score + >>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1]) + >>> print(data_sample) + + >>> + >>> # Set custom field + >>> data_sample = DataSample() + >>> data_sample.my_field = [1, 2, 3] + >>> print(data_sample) + + >>> print(data_sample.my_field) + [1, 2, 3] + """ + + def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample': + """Set ``gt_label``.""" + self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor) + return self + + def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample': + """Set ``gt_score``.""" + score = format_score(value) + self.set_field(score, 'gt_score', dtype=torch.Tensor) + if hasattr(self, 'num_classes'): + assert len(score) == self.num_classes, \ + f'The length of score {len(score)} should be '\ + f'equal to the num_classes {self.num_classes}.' 
+ else: + self.set_field( + name='num_classes', value=len(score), field_type='metainfo') + return self + + def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample': + """Set ``pred_label``.""" + self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor) + return self + + def set_pred_score(self, value: SCORE_TYPE): + """Set ``pred_label``.""" + score = format_score(value) + self.set_field(score, 'pred_score', dtype=torch.Tensor) + if hasattr(self, 'num_classes'): + assert len(score) == self.num_classes, \ + f'The length of score {len(score)} should be '\ + f'equal to the num_classes {self.num_classes}.' + else: + self.set_field( + name='num_classes', value=len(score), field_type='metainfo') + return self + + def set_mask(self, value: Union[torch.Tensor, np.ndarray]): + if isinstance(value, np.ndarray): + value = torch.from_numpy(value) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Invalid mask type {type(value)}') + self.set_field(value, 'mask', dtype=torch.Tensor) + return self + + def __repr__(self) -> str: + """Represent the object.""" + + def dump_items(items, prefix=''): + return '\n'.join(f'{prefix}{k}: {v}' for k, v in items) + + repr_ = '' + if len(self._metainfo_fields) > 0: + repr_ += '\n\nMETA INFORMATION\n' + repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4) + if len(self._data_fields) > 0: + repr_ += '\n\nDATA FIELDS\n' + repr_ += dump_items(self.items(), prefix=' ' * 4) + + repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>' + return repr_ + + +def _reduce_datasample(data_sample): + """reduce DataSample.""" + attr_dict = data_sample.__dict__ + convert_keys = [] + for k, v in attr_dict.items(): + if isinstance(v, torch.Tensor): + attr_dict[k] = v.numpy() + convert_keys.append(k) + return _rebuild_datasample, (attr_dict, convert_keys) + + +def _rebuild_datasample(attr_dict, convert_keys): + """rebuild DataSample.""" + data_sample = DataSample() + for k in convert_keys: + attr_dict[k] = torch.from_numpy(attr_dict[k]) + data_sample.__dict__ = attr_dict + return data_sample + + +# Due to the multi-processing strategy of PyTorch, DataSample may consume many +# file descriptors because it contains multiple tensors. Here we overwrite the +# reduce function of DataSample in ForkingPickler and convert these tensors to +# np.ndarray during pickling. It may slightly influence the performance of +# dataloader. +ForkingPickler.register(DataSample, _reduce_datasample) diff --git a/mmpretrain/structures/multi_task_data_sample.py b/mmpretrain/structures/multi_task_data_sample.py new file mode 100644 index 0000000..f009938 --- /dev/null +++ b/mmpretrain/structures/multi_task_data_sample.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmengine.structures import BaseDataElement + + +class MultiTaskDataSample(BaseDataElement): + + @property + def tasks(self): + return self._data_fields diff --git a/mmpretrain/structures/utils.py b/mmpretrain/structures/utils.py new file mode 100644 index 0000000..a4f9e95 --- /dev/null +++ b/mmpretrain/structures/utils.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.utils import is_str + +if hasattr(torch, 'tensor_split'): + tensor_split = torch.tensor_split +else: + # A simple implementation of `tensor_split`. 
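A short sketch of the DataSample interface defined above: setting a score also records num_classes in the meta information, and a later score of a different length would trip the consistency assert.

import torch
from mmpretrain.structures import DataSample

sample = DataSample(metainfo=dict(img_shape=(224, 224)))
sample.set_gt_label(3).set_pred_label(2).set_pred_score(
    torch.tensor([0.1, 0.2, 0.6, 0.1]))
print(sample.gt_label)      # tensor([3])
print(sample.num_classes)   # 4, inferred from the score length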
+ def tensor_split(input: torch.Tensor, indices: list): + outs = [] + for start, end in zip([0] + indices, indices + [input.size(0)]): + outs.append(input[start:end]) + return outs + + +LABEL_TYPE = Union[torch.Tensor, np.ndarray, Sequence, int] +SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence] + + +def format_label(value: LABEL_TYPE) -> torch.Tensor: + """Convert various python types to label-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence | int): Label value. + + Returns: + :obj:`torch.Tensor`: The foramtted label tensor. + """ + + # Handle single number + if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0: + value = int(value.item()) + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(torch.long) + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).to(torch.long) + elif isinstance(value, int): + value = torch.LongTensor([value]) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +def format_score(value: SCORE_TYPE) -> torch.Tensor: + """Convert various python types to score-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence): Score values. + + Returns: + :obj:`torch.Tensor`: The foramtted score tensor. + """ + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).float() + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).float() + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +def cat_batch_labels(elements: List[torch.Tensor]): + """Concat a batch of label tensor to one tensor. + + Args: + elements (List[tensor]): A batch of labels. + + Returns: + Tuple[torch.Tensor, List[int]]: The first item is the concated label + tensor, and the second item is the split indices of every sample. + """ + labels = [] + splits = [0] + for element in elements: + labels.append(element) + splits.append(splits[-1] + element.size(0)) + batch_label = torch.cat(labels) + return batch_label, splits[1:-1] + + +def batch_label_to_onehot(batch_label, split_indices, num_classes): + """Convert a concated label tensor to onehot format. + + Args: + batch_label (torch.Tensor): A concated label tensor from multiple + samples. + split_indices (List[int]): The split indices of every sample. + num_classes (int): The number of classes. + + Returns: + torch.Tensor: The onehot format label tensor. + + Examples: + >>> import torch + >>> from mmpretrain.structures import batch_label_to_onehot + >>> # Assume a concated label from 3 samples. 
+ >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1] + >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1]) + >>> split_indices = [2, 5] + >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5) + tensor([[1, 1, 0, 0, 0], + [1, 0, 1, 0, 1], + [0, 1, 0, 1, 0]]) + """ + sparse_onehot_list = F.one_hot(batch_label, num_classes) + onehot_list = [ + sparse_onehot.sum(0) + for sparse_onehot in tensor_split(sparse_onehot_list, split_indices) + ] + return torch.stack(onehot_list) + + +def label_to_onehot(label: LABEL_TYPE, num_classes: int): + """Convert a label to onehot format tensor. + + Args: + label (LABEL_TYPE): Label value. + num_classes (int): The number of classes. + + Returns: + torch.Tensor: The onehot format label tensor. + + Examples: + >>> import torch + >>> from mmpretrain.structures import label_to_onehot + >>> # Single-label + >>> label_to_onehot(1, num_classes=5) + tensor([0, 1, 0, 0, 0]) + >>> # Multi-label + >>> label_to_onehot([0, 2, 3], num_classes=5) + tensor([1, 0, 1, 1, 0]) + """ + label = format_label(label) + sparse_onehot = F.one_hot(label, num_classes) + return sparse_onehot.sum(0) diff --git a/mmpretrain/utils/__init__.py b/mmpretrain/utils/__init__.py new file mode 100644 index 0000000..991e321 --- /dev/null +++ b/mmpretrain/utils/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .analyze import load_json_log +from .collect_env import collect_env +from .dependency import require +from .misc import get_ori_model +from .progress import track, track_on_main_process +from .setup_env import register_all_modules + +__all__ = [ + 'collect_env', 'register_all_modules', 'track_on_main_process', + 'load_json_log', 'get_ori_model', 'track', 'require' +] diff --git a/mmpretrain/utils/analyze.py b/mmpretrain/utils/analyze.py new file mode 100644 index 0000000..a933591 --- /dev/null +++ b/mmpretrain/utils/analyze.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json + + +def load_json_log(json_log): + """load and convert json_logs to log_dicts. + + Args: + json_log (str): The path of the json log file. + + Returns: + dict: The result dict contains two items, "train" and "val", for + the training log and validate log. + + Example: + An example output: + + .. code-block:: python + + { + 'train': [ + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 100}, + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 200}, + {"lr": 0.1, "time": 0.02, "epoch": 1, "step": 300}, + ... + ] + 'val': [ + {"accuracy/top1": 32.1, "step": 1}, + {"accuracy/top1": 50.2, "step": 2}, + {"accuracy/top1": 60.3, "step": 2}, + ... + ] + } + """ + log_dict = dict(train=[], val=[]) + with open(json_log, 'r') as log_file: + for line in log_file: + log = json.loads(line.strip()) + # A hack trick to determine whether the line is training log. + mode = 'train' if 'lr' in log else 'val' + log_dict[mode].append(log) + + return log_dict diff --git a/mmpretrain/utils/collect_env.py b/mmpretrain/utils/collect_env.py new file mode 100644 index 0000000..988451e --- /dev/null +++ b/mmpretrain/utils/collect_env.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
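The helpers above are designed to round-trip: cat_batch_labels flattens a batch of variable-length label tensors and returns the split points, which batch_label_to_onehot then uses to recover per-sample one-hot vectors. A small sketch mirroring the docstring example:

import torch
from mmpretrain.structures import batch_label_to_onehot, cat_batch_labels

labels = [torch.tensor([0, 1]), torch.tensor([0, 2, 4]), torch.tensor([3, 1])]
batch_label, splits = cat_batch_labels(labels)
# batch_label: tensor([0, 1, 0, 2, 4, 3, 1]), splits: [2, 5]
onehot = batch_label_to_onehot(batch_label, splits, num_classes=5)
# tensor([[1, 1, 0, 0, 0],
#         [1, 0, 1, 0, 1],
#         [0, 1, 0, 1, 0]])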
+import mmcv +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + +import mmpretrain + + +def collect_env(with_torch_comiling_info=False): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMCV'] = mmcv.__version__ + if not with_torch_comiling_info: + env_info.pop('PyTorch compiling details') + env_info['MMPreTrain'] = mmpretrain.__version__ + '+' + get_git_hash()[:7] + return env_info diff --git a/mmpretrain/utils/dependency.py b/mmpretrain/utils/dependency.py new file mode 100644 index 0000000..0e3d8ae --- /dev/null +++ b/mmpretrain/utils/dependency.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from functools import wraps +from inspect import isfunction + +from importlib_metadata import PackageNotFoundError, distribution +from mmengine.utils import digit_version + + +def satisfy_requirement(dep): + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, dep, maxsplit=1) + parts = [p.strip() for p in parts] + package = parts[0] + if len(parts) > 1: + op, version = parts[1:] + op = { + '>=': '__ge__', + '==': '__eq__', + '>': '__gt__', + '<': '__lt__', + '<=': '__le__' + }[op] + else: + op, version = None, None + + try: + dist = distribution(package) + if op is None or getattr(digit_version(dist.version), op)( + digit_version(version)): + return True + except PackageNotFoundError: + pass + + return False + + +def require(dep, install=None): + """A wrapper of function for extra package requirements. + + Args: + dep (str): The dependency package name, like ``transformers`` + or ``transformers>=4.28.0``. + install (str, optional): The installation command hint. Defaults + to None, which means to use "pip install dep". + """ + + def wrapper(fn): + assert isfunction(fn) + + @wraps(fn) + def ask_install(*args, **kwargs): + name = fn.__qualname__.replace('.__init__', '') + ins = install or f'pip install "{dep}"' + raise ImportError( + f'{name} requires {dep}, please install it by `{ins}`.') + + if satisfy_requirement(dep): + fn._verify_require = getattr(fn, '_verify_require', lambda: None) + return fn + + ask_install._verify_require = ask_install + return ask_install + + return wrapper + + +WITH_MULTIMODAL = all( + satisfy_requirement(item) + for item in ['pycocotools', 'transformers>=4.28.0']) + + +def register_multimodal_placeholder(names, registry): + for name in names: + + def ask_install(*args, **kwargs): + raise ImportError( + f'{name} requires extra multi-modal dependencies, please ' + 'install it by `pip install "mmpretrain[multimodal]"` ' + 'or `pip install -e ".[multimodal]"`.') + + registry.register_module(name=name, module=ask_install) diff --git a/mmpretrain/utils/misc.py b/mmpretrain/utils/misc.py new file mode 100644 index 0000000..cc53267 --- /dev/null +++ b/mmpretrain/utils/misc.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmengine.model import is_model_wrapper + + +def get_ori_model(model: nn.Module) -> nn.Module: + """Get original model if the input model is a model wrapper. + + Args: + model (nn.Module): A model may be a model wrapper. + + Returns: + nn.Module: The model without model wrapper. 
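A hedged sketch of the require decorator above (the package and model name are only illustrative): when the requirement is not satisfied, calling the wrapped function raises an ImportError with an install hint instead of failing deep inside an import.

from mmpretrain.utils import require

@require('transformers>=4.28.0')
def build_hf_tokenizer(name):
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(name)

# With transformers missing or too old, calling build_hf_tokenizer raises:
# ImportError: build_hf_tokenizer requires transformers>=4.28.0, please
# install it by `pip install "transformers>=4.28.0"`.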
+ """ + if is_model_wrapper(model): + return model.module + else: + return model diff --git a/mmpretrain/utils/progress.py b/mmpretrain/utils/progress.py new file mode 100644 index 0000000..b23f976 --- /dev/null +++ b/mmpretrain/utils/progress.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import mmengine.dist as dist +import rich.progress as progress +from rich.live import Live + +disable_progress_bar = False +global_progress = progress.Progress( + '{task.description}', + progress.BarColumn(), + progress.TaskProgressColumn(show_speed=True), + progress.TimeRemainingColumn(), +) +global_live = Live(global_progress, refresh_per_second=10) + + +def track(sequence, description: str = '', total: Optional[float] = None): + if disable_progress_bar: + yield from sequence + else: + global_live.start() + task_id = global_progress.add_task(description, total=total) + task = global_progress._tasks[task_id] + try: + yield from global_progress.track(sequence, task_id=task_id) + finally: + if task.total is None: + global_progress.update(task_id, total=task.completed) + if all(task.finished for task in global_progress.tasks): + global_live.stop() + for task_id in global_progress.task_ids: + global_progress.remove_task(task_id) + + +def track_on_main_process(sequence, description='', total=None): + if not dist.is_main_process() or disable_progress_bar: + yield from sequence + else: + yield from track(sequence, total=total, description=description) diff --git a/mmpretrain/utils/setup_env.py b/mmpretrain/utils/setup_env.py new file mode 100644 index 0000000..1b57b84 --- /dev/null +++ b/mmpretrain/utils/setup_env.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import warnings + +from mmengine import DefaultScope + + +def register_all_modules(init_default_scope: bool = True) -> None: + """Register all modules in mmpretrain into the registries. + + Args: + init_default_scope (bool): Whether initialize the mmpretrain default + scope. If True, the global default scope will be set to + `mmpretrain`, and all registries will build modules from + mmpretrain's registry node. To understand more about the registry, + please refer to + https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md + Defaults to True. + """ # noqa: E501 + import mmpretrain.datasets # noqa: F401,F403 + import mmpretrain.engine # noqa: F401,F403 + import mmpretrain.evaluation # noqa: F401,F403 + import mmpretrain.models # noqa: F401,F403 + import mmpretrain.structures # noqa: F401,F403 + import mmpretrain.visualization # noqa: F401,F403 + + if not init_default_scope: + return + + current_scope = DefaultScope.get_current_instance() + if current_scope is None: + DefaultScope.get_instance('mmpretrain', scope_name='mmpretrain') + elif current_scope.scope_name != 'mmpretrain': + warnings.warn( + f'The current default scope "{current_scope.scope_name}" ' + 'is not "mmpretrain", `register_all_modules` will force ' + 'the current default scope to be "mmpretrain". If this is ' + 'not expected, please set `init_default_scope=False`.') + # avoid name conflict + new_instance_name = f'mmpretrain-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name='mmpretrain') diff --git a/mmpretrain/version.py b/mmpretrain/version.py new file mode 100644 index 0000000..1822b7f --- /dev/null +++ b/mmpretrain/version.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. 
All rights reserved + +__version__ = '1.2.0' + + +def parse_version_info(version_str): + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). + """ + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] diff --git a/mmpretrain/visualization/__init__.py b/mmpretrain/visualization/__init__.py new file mode 100644 index 0000000..0dbeecf --- /dev/null +++ b/mmpretrain/visualization/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .utils import create_figure, get_adaptive_scale +from .visualizer import UniversalVisualizer + +__all__ = ['UniversalVisualizer', 'get_adaptive_scale', 'create_figure'] diff --git a/mmpretrain/visualization/utils.py b/mmpretrain/visualization/utils.py new file mode 100644 index 0000000..91a1d81 --- /dev/null +++ b/mmpretrain/visualization/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import TYPE_CHECKING, Tuple + +if TYPE_CHECKING: + from matplotlib.figure import Figure + + +def get_adaptive_scale(img_shape: Tuple[int, int], + min_scale: float = 0.3, + max_scale: float = 3.0) -> float: + """Get adaptive scale according to image shape. + + The target scale depends on the the short edge length of the image. If the + short edge length equals 224, the output is 1.0. And output linear scales + according the short edge length. + + You can also specify the minimum scale and the maximum scale to limit the + linear scale. + + Args: + img_shape (Tuple[int, int]): The shape of the canvas image. + min_size (int): The minimum scale. Defaults to 0.3. + max_size (int): The maximum scale. Defaults to 3.0. + + Returns: + int: The adaptive scale. + """ + short_edge_length = min(img_shape) + scale = short_edge_length / 224. + return min(max(scale, min_scale), max_scale) + + +def create_figure(*args, margin=False, **kwargs) -> 'Figure': + """Create a independent figure. + + Different from the :func:`plt.figure`, the figure from this function won't + be managed by matplotlib. And it has + :obj:`matplotlib.backends.backend_agg.FigureCanvasAgg`, and therefore, you + can use the ``canvas`` attribute to get access the drawn image. + + Args: + *args: All positional arguments of :class:`matplotlib.figure.Figure`. + margin: Whether to reserve the white edges of the figure. + Defaults to False. + **kwargs: All keyword arguments of :class:`matplotlib.figure.Figure`. + + Return: + matplotlib.figure.Figure: The created figure. + """ + from matplotlib.backends.backend_agg import FigureCanvasAgg + from matplotlib.figure import Figure + + figure = Figure(*args, **kwargs) + FigureCanvasAgg(figure) + + if not margin: + # remove white edges by set subplot margin + figure.subplots_adjust(left=0, right=1, bottom=0, top=1) + + return figure diff --git a/mmpretrain/visualization/visualizer.py b/mmpretrain/visualization/visualizer.py new file mode 100644 index 0000000..5d18ca8 --- /dev/null +++ b/mmpretrain/visualization/visualizer.py @@ -0,0 +1,777 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
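A short sketch of the two visualization helpers above: get_adaptive_scale maps the short image edge to a text/marker scale (1.0 at 224 px, clamped to [0.3, 3.0]), and create_figure returns an unmanaged figure whose Agg canvas can be rasterized directly:

import numpy as np
from mmpretrain.visualization import create_figure, get_adaptive_scale

scale = get_adaptive_scale((448, 640))   # short edge 448 -> scale 2.0
figure = create_figure(figsize=(4, 4))   # margins removed by default
ax = figure.add_subplot()
ax.imshow(np.random.rand(224, 224, 3))
figure.canvas.draw()                     # rasterize via FigureCanvasAgg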
+from typing import Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.dataset import BaseDataset +from mmengine.dist import master_only +from mmengine.visualization import Visualizer +from mmengine.visualization.utils import img_from_canvas + +from mmpretrain.registry import VISUALIZERS +from mmpretrain.structures import DataSample +from .utils import create_figure, get_adaptive_scale + + +@VISUALIZERS.register_module() +class UniversalVisualizer(Visualizer): + """Universal Visualizer for multiple tasks. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + fig_save_cfg (dict): Keyword parameters of figure for saving. + Defaults to empty dict. + fig_show_cfg (dict): Keyword parameters of figure for showing. + Defaults to empty dict. + """ + DEFAULT_TEXT_CFG = { + 'family': 'monospace', + 'color': 'white', + 'bbox': dict(facecolor='black', alpha=0.5, boxstyle='Round'), + 'verticalalignment': 'top', + 'horizontalalignment': 'left', + } + + @master_only + def visualize_cls(self, + image: np.ndarray, + data_sample: DataSample, + classes: Optional[Sequence[str]] = None, + draw_gt: bool = True, + draw_pred: bool = True, + draw_score: bool = True, + resize: Optional[int] = None, + rescale_factor: Optional[float] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize image classification result. + + This method will draw an text box on the input image to visualize the + information about image classification, like the ground-truth label and + prediction label. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + classes (Sequence[str], optional): The categories names. + Defaults to None. + draw_gt (bool): Whether to draw ground-truth labels. + Defaults to True. + draw_pred (bool): Whether to draw prediction labels. + Defaults to True. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + rescale_factor (float, optional): Rescale the image by the rescale + factor before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts + arguments of :meth:`mmengine.Visualizer.draw_texts`. + Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. 
It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + if self.dataset_meta is not None: + classes = classes or self.dataset_meta.get('classes', None) + + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h // w)) + else: + image = mmcv.imresize(image, (resize * w // h, resize)) + elif rescale_factor is not None: + image = mmcv.imrescale(image, rescale_factor) + + texts = [] + self.set_image(image) + + if draw_gt and 'gt_label' in data_sample: + idx = data_sample.gt_label.tolist() + class_labels = [''] * len(idx) + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))] + prefix = 'Ground truth: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + if draw_pred and 'pred_label' in data_sample: + idx = data_sample.pred_label.tolist() + score_labels = [''] * len(idx) + class_labels = [''] * len(idx) + if draw_score and 'pred_score' in data_sample: + score_labels = [ + f', {data_sample.pred_score[i].item():.2f}' for i in idx + ] + + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + + labels = [ + str(idx[i]) + score_labels[i] + class_labels[i] + for i in range(len(idx)) + ] + prefix = 'Prediction: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + self.ax_save.text( + img_scale * 5, + img_scale * 5, + '\n'.join(texts), + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_image_retrieval(self, + image: np.ndarray, + data_sample: DataSample, + prototype_dataset: BaseDataset, + topk: int = 1, + draw_score: bool = True, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + step: int = 0) -> None: + """Visualize image retrieval result. + + This method will draw the input image and the images retrieved from the + prototype dataset. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + prototype_dataset (:obj:`BaseDataset`): The prototype dataset. + It should have `get_data_info` method and return a dict + includes `img_path`. + draw_score (bool): Whether to draw the match scores of the + retrieved images. Defaults to True. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. 
If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + if resize is not None: + image = mmcv.imrescale(image, (resize, resize)) + + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + + figure = create_figure(margin=True) + gs = figure.add_gridspec(2, topk) + query_plot = figure.add_subplot(gs[0, :]) + query_plot.axis(False) + query_plot.imshow(image) + + for k, (score, sample_idx) in enumerate(zip(match_scores, indices)): + sample = prototype_dataset.get_data_info(sample_idx.item()) + value_image = mmcv.imread(sample['img_path'])[..., ::-1] + value_plot = figure.add_subplot(gs[1, k]) + value_plot.axis(False) + value_plot.imshow(value_image) + if draw_score: + value_plot.text( + 5, + 5, + f'{score:.2f}', + **text_cfg, + ) + drawn_img = img_from_canvas(figure.canvas) + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + def add_mask_to_image( + self, + image: np.ndarray, + data_sample: DataSample, + resize: Union[int, Tuple[int]] = 224, + color: Union[str, Tuple[int]] = 'black', + alpha: Union[int, float] = 0.8, + ) -> np.ndarray: + if isinstance(resize, int): + resize = (resize, resize) + + image = mmcv.imresize(image, resize) + self.set_image(image) + + if isinstance(data_sample.mask, np.ndarray): + data_sample.mask = torch.tensor(data_sample.mask) + mask = data_sample.mask.float()[None, None, ...] + mask_ = F.interpolate(mask, image.shape[:2], mode='nearest')[0, 0] + + self.draw_binary_masks(mask_.bool(), colors=color, alphas=alpha) + + drawn_img = self.get_image() + return drawn_img + + @master_only + def visualize_masked_image(self, + image: np.ndarray, + data_sample: DataSample, + resize: Union[int, Tuple[int]] = 224, + color: Union[str, Tuple[int]] = 'black', + alpha: Union[int, float] = 0.8, + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize masked image. + + This method will draw an image with binary mask. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + resize (int | Tuple[int]): Resize the input image to the specified + shape. Defaults to 224. + color (str | Tuple[int]): The color of the binary mask. + Defaults to "black". + alpha (int | float): The transparency of the mask. Defaults to 0.8. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. 
+            name (str): The image identifier. It's useful when using the
+                storage backends of the visualizer to save or display the
+                image. Defaults to an empty string.
+            step (int): The global step value. It's useful to record a
+                series of visualization results for the same image with the
+                storage backends. Defaults to 0.
+
+        Returns:
+            np.ndarray: The visualization image.
+        """
+        drawn_img = self.add_mask_to_image(
+            image=image,
+            data_sample=data_sample,
+            resize=resize,
+            color=color,
+            alpha=alpha)
+
+        if show:
+            self.show(drawn_img, win_name=name, wait_time=wait_time)
+
+        if out_file is not None:
+            # save the image to the target file instead of vis_backends
+            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+        else:
+            self.add_image(name, drawn_img, step=step)
+
+        return drawn_img
+
+    @master_only
+    def visualize_image_caption(self,
+                                image: np.ndarray,
+                                data_sample: DataSample,
+                                resize: Optional[int] = None,
+                                text_cfg: dict = dict(),
+                                show: bool = False,
+                                wait_time: float = 0,
+                                out_file: Optional[str] = None,
+                                name: Optional[str] = '',
+                                step: int = 0) -> np.ndarray:
+        """Visualize image caption result.
+
+        This method will draw the input image and its predicted caption.
+
+        Args:
+            image (np.ndarray): The image to draw. The format should be RGB.
+            data_sample (:obj:`DataSample`): The annotation of the image.
+            resize (int, optional): Resize the short edge of the image to
+                the specified length before visualization. Defaults to None.
+            text_cfg (dict): Extra text setting, which accepts arguments of
+                :func:`plt.text`. Defaults to an empty dict.
+            show (bool): Whether to display the drawn image in a window,
+                please confirm you are able to access the graphical
+                interface. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which
+                means "forever".
+            out_file (str, optional): Extra path to save the visualization
+                result. If specified, the visualizer will only save the
+                result image to the out_file and ignore its storage
+                backends. Defaults to None.
+            name (str): The image identifier. It's useful when using the
+                storage backends of the visualizer to save or display the
+                image. Defaults to an empty string.
+            step (int): The global step value. It's useful to record a
+                series of visualization results for the same image with the
+                storage backends. Defaults to 0.
+
+        Returns:
+            np.ndarray: The visualization image.
+        """
+        text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg}
+
+        if resize is not None:
+            h, w = image.shape[:2]
+            if w < h:
+                image = mmcv.imresize(image, (resize, resize * h // w))
+            else:
+                image = mmcv.imresize(image, (resize * w // h, resize))
+
+        self.set_image(image)
+
+        img_scale = get_adaptive_scale(image.shape[:2])
+        text_cfg = {
+            'size': int(img_scale * 7),
+            **self.DEFAULT_TEXT_CFG,
+            **text_cfg,
+        }
+        self.ax_save.text(
+            img_scale * 5,
+            img_scale * 5,
+            data_sample.get('pred_caption'),
+            wrap=True,
+            **text_cfg,
+        )
+        drawn_img = self.get_image()
+
+        if show:
+            self.show(drawn_img, win_name=name, wait_time=wait_time)
+
+        if out_file is not None:
+            # save the image to the target file instead of vis_backends
+            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+        else:
+            self.add_image(name, drawn_img, step=step)
+
+        return drawn_img
+
+    @master_only
+    def visualize_vqa(self,
+                      image: np.ndarray,
+                      data_sample: DataSample,
+                      resize: Optional[int] = None,
+                      text_cfg: dict = dict(),
+                      show: bool = False,
+                      wait_time: float = 0,
+                      out_file: Optional[str] = None,
+                      name: Optional[str] = '',
+                      step: int = 0) -> np.ndarray:
+        """Visualize visual question answering result.
+
+        This method will draw the input image, question and answer.
+
+        Args:
+            image (np.ndarray): The image to draw. The format should be RGB.
+            data_sample (:obj:`DataSample`): The annotation of the image.
+            resize (int, optional): Resize the short edge of the image to
+                the specified length before visualization. Defaults to None.
+            text_cfg (dict): Extra text setting, which accepts arguments of
+                :func:`plt.text`. Defaults to an empty dict.
+            show (bool): Whether to display the drawn image in a window,
+                please confirm you are able to access the graphical
+                interface. Defaults to False.
+            wait_time (float): The display time (s). Defaults to 0, which
+                means "forever".
+            out_file (str, optional): Extra path to save the visualization
+                result. If specified, the visualizer will only save the
+                result image to the out_file and ignore its storage
+                backends. Defaults to None.
+            name (str): The image identifier. It's useful when using the
+                storage backends of the visualizer to save or display the
+                image. Defaults to an empty string.
+            step (int): The global step value. It's useful to record a
+                series of visualization results for the same image with the
+                storage backends. Defaults to 0.
+
+        Returns:
+            np.ndarray: The visualization image.
+ """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h // w)) + else: + image = mmcv.imresize(image, (resize * w // h, resize)) + + self.set_image(image) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + text = (f'Q: {data_sample.get("question")}\n' + f'A: {data_sample.get("pred_answer")}') + self.ax_save.text( + img_scale * 5, + img_scale * 5, + text, + wrap=True, + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_visual_grounding(self, + image: np.ndarray, + data_sample: DataSample, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + line_width: Union[int, float] = 3, + bbox_color: Union[str, tuple] = 'green', + step: int = 0) -> None: + """Visualize visual grounding result. + + This method will draw the input image, bbox and the object. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + + gt_bboxes = data_sample.get('gt_bboxes') + pred_bboxes = data_sample.get('pred_bboxes') + if resize is not None: + h, w = image.shape[:2] + if w < h: + image, w_scale, h_scale = mmcv.imresize( + image, (resize, resize * h // w), return_scale=True) + else: + image, w_scale, h_scale = mmcv.imresize( + image, (resize * w // h, resize), return_scale=True) + pred_bboxes[:, ::2] *= w_scale + pred_bboxes[:, 1::2] *= h_scale + if gt_bboxes is not None: + gt_bboxes[:, ::2] *= w_scale + gt_bboxes[:, 1::2] *= h_scale + + self.set_image(image) + # Avoid the line-width limit in the base classes. 
+ self._default_font_size = 1e3 + self.draw_bboxes( + pred_bboxes, line_widths=line_width, edge_colors=bbox_color) + if gt_bboxes is not None: + self.draw_bboxes( + gt_bboxes, line_widths=line_width, edge_colors='blue') + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + + text_positions = pred_bboxes[:, :2] + line_width + for i in range(pred_bboxes.size(0)): + self.ax_save.text( + text_positions[i, 0] + line_width, + text_positions[i, 1] + line_width, + data_sample.get('text'), + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_t2i_retrieval(self, + text: str, + data_sample: DataSample, + prototype_dataset: BaseDataset, + topk: int = 1, + draw_score: bool = True, + text_cfg: dict = dict(), + fig_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + step: int = 0) -> None: + """Visualize Text-To-Image retrieval result. + + This method will draw the input text and the images retrieved from the + prototype dataset. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + prototype_dataset (:obj:`BaseDataset`): The prototype dataset. + It should have `get_data_info` method and return a dict + includes `img_path`. + topk (int): To visualize the topk matching items. Defaults to 1. + draw_score (bool): Whether to draw the match scores of the + retrieved images. Defaults to True. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + fig_cfg (dict): Extra figure setting, which accepts arguments of + :func:`plt.Figure`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. 
+ """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + + figure = create_figure(margin=True, **fig_cfg) + figure.suptitle(text) + gs = figure.add_gridspec(1, topk) + + for k, (score, sample_idx) in enumerate(zip(match_scores, indices)): + sample = prototype_dataset.get_data_info(sample_idx.item()) + value_image = mmcv.imread(sample['img_path'])[..., ::-1] + value_plot = figure.add_subplot(gs[0, k]) + value_plot.axis(False) + value_plot.imshow(value_image) + if draw_score: + value_plot.text( + 5, + 5, + f'{score:.2f}', + **text_cfg, + ) + drawn_img = img_from_canvas(figure.canvas) + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_i2t_retrieval(self, + image: np.ndarray, + data_sample: DataSample, + prototype_dataset: Sequence[str], + topk: int = 1, + draw_score: bool = True, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize Image-To-Text retrieval result. + + This method will draw the input image and the texts retrieved from the + prototype dataset. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + prototype_dataset (Sequence[str]): The prototype dataset. + It should be a list of texts. + topk (int): To visualize the topk matching items. Defaults to 1. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts + arguments of :meth:`mmengine.Visualizer.draw_texts`. + Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. 
+ """ + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h // w)) + else: + image = mmcv.imresize(image, (resize * w // h, resize)) + + self.set_image(image) + + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + texts = [] + for score, sample_idx in zip(match_scores, indices): + text = prototype_dataset[sample_idx.item()] + if draw_score: + text = f'{score:.2f} ' + text + texts.append(text) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + self.ax_save.text( + img_scale * 5, + img_scale * 5, + '\n'.join(texts), + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img -- GitLab