dcuai / dlexamples / Commits / 142dcf29

Commit 142dcf29, authored Apr 15, 2022 by hepj
Add Conformer code
Parent: 7f99c1c3
Changes: 317 files. Showing 20 changed files with 5950 additions and 0 deletions (+5950, -0).
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/lvis.py  +742  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/__init__.py  +25  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/auto_augment.py  +890  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/compose.py  +51  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/formating.py  +364  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/instaboost.py  +98  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/loading.py  +458  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/test_time_aug.py  +119  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/transforms.py  +1804  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/samplers/__init__.py  +4  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/samplers/distributed_sampler.py  +32  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/samplers/group_sampler.py  +143  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/utils.py  +100  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/voc.py  +93  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/wider_face.py  +51  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/xml_style.py  +169  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/models/__init__.py  +16  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/models/backbones/Conformer.py  +574  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/models/backbones/__init__.py  +18  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/models/backbones/darknet.py  +199  -0
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/lvis.py  (new file, 0 → 100644)
import itertools
import logging
import os.path as osp
import tempfile
from collections import OrderedDict

import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable

from .builder import DATASETS
from .coco import CocoDataset


@DATASETS.register_module()
class LVISV05Dataset(CocoDataset):

    CLASSES = (
        'acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol',
        'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna',
        'apple', 'apple_juice', 'applesauce', 'apricot', 'apron', 'aquarium',
        'armband', 'armchair', 'armoire', 'armor', 'artichoke', 'trash_can',
        'ashtray', 'asparagus', 'atomizer', 'avocado', 'award', 'awning',
        'ax', 'baby_buggy', 'basketball_backboard', 'backpack', 'handbag', 'suitcase',
        'bagel', 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt',
        'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna',
        'banjo', 'banner', 'barbell', 'barge', 'barrel', 'barrette',
        'barrow', 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', 'baseball_glove',
        'basket', 'basketball_hoop', 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat',
        'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball',
        'bead', 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear',
        'bed', 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle',
        'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle',
        'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
        'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage',
        'birdhouse', 'birthday_cake', 'birthday_card', 'biscuit_(bread)', 'pirate_flag', 'black_sheep',
        'blackboard', 'blanket', 'blazer', 'blender', 'blimp', 'blinker',
        'blueberry', 'boar', 'gameboard', 'boat', 'bobbin', 'bobby_pin',
        'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book',
        'book_bag', 'bookcase', 'booklet', 'bookmark', 'boom_microphone', 'boot',
        'bottle', 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie',
        'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin', 'boxing_glove',
        'suspenders', 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'breechcloth',
        'bridal_gown', 'briefcase', 'bristle_brush', 'broccoli', 'broach', 'broom',
        'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
        'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn',
        'corned_beef', 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)',
        'business_card', 'butcher_knife', 'butter', 'butterfly', 'button', 'cab_(taxi)',
        'cabana', 'cabin_car', 'cabinet', 'locker', 'cake', 'calculator',
        'calendar', 'calf', 'camcorder', 'camel', 'camera', 'camera_lens',
        'camper_(vehicle)', 'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
        'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon', 'canoe',
        'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
        'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', 'car_battery', 'identity_card', 'card',
        'cardigan', 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
        'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
        'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player', 'celery',
        'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', 'champagne', 'chandelier',
        'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard', 'chest_of_drawers_(furniture)',
        'chicken_(animal)', 'chicken_wire', 'chickpea', 'Chihuahua', 'chili_(vegetable)', 'chime',
        'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
        'chocolate_mousse', 'choker', 'chopping_board', 'chopstick', 'Christmas_tree', 'slide',
        'cider', 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
        'clasp', 'cleansing_agent', 'clementine', 'clip', 'clipboard', 'clock',
        'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
        'coat_hanger', 'coatrack', 'cock', 'coconut', 'coffee_filter', 'coffee_maker',
        'coffee_table', 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw',
        'coloring_material', 'combination_lock', 'pacifier', 'comic_book', 'computer_keyboard', 'concrete_mixer',
        'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cookie', 'cookie_jar',
        'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn',
        'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset', 'romaine_lettuce',
        'costume', 'cougar', 'coverall', 'cowbell', 'cowboy_hat', 'crab_(animal)',
        'cracker', 'crape', 'crate', 'crayon', 'cream_pitcher', 'credit_card',
        'crescent_roll', 'crib', 'crock_pot', 'crossbar', 'crouton', 'crow',
        'crown', 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
        'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
        'cupcake', 'hair_curler', 'curling_iron', 'curtain', 'cushion', 'custard',
        'cutting_tool', 'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
        'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', 'detergent',
        'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
        'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', 'dishwasher_detergent',
        'diskette', 'dispenser', 'Dixie_cup', 'dog', 'dog_collar', 'doll',
        'dollar', 'dolphin', 'domestic_ass', 'eye_mask', 'doorbell', 'doorknob',
        'doormat', 'doughnut', 'dove', 'dragonfly', 'drawer', 'underdrawers',
        'dress', 'dress_hat', 'dress_suit', 'dresser', 'drill', 'drinking_fountain',
        'drone', 'dropper', 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
        'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'Dutch_oven',
        'eagle', 'earphone', 'earplug', 'earring', 'easel', 'eclair',
        'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant',
        'electric_chair', 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser',
        'escargot', 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora',
        'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine',
        'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose',
        'fireplace', 'fireplug', 'fish', 'fish_(food)', 'fishbowl', 'fishing_boat',
        'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flash',
        'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', 'flower_arrangement', 'flute_glass',
        'foal', 'folding_chair', 'food_processor', 'football_(American)', 'football_helmet', 'footstool',
        'fork', 'forklift', 'freight_car', 'French_toast', 'freshener', 'frisbee',
        'frog', 'fruit_juice', 'fruit_salad', 'frying_pan', 'fudge', 'funnel',
        'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose', 'gargle',
        'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone',
        'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', 'glass_(drink_container)',
        'globe', 'glove', 'goat', 'goggles', 'goldfish', 'golf_club',
        'golfcart', 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'surgical_gown',
        'grape', 'grasshopper', 'grater', 'gravestone', 'gravy_boat', 'green_bean',
        'green_onion', 'griddle', 'grillroom', 'grinder_(tool)', 'grits', 'grizzly',
        'grocery_bag', 'guacamole', 'guitar', 'gull', 'gun', 'hair_spray',
        'hairbrush', 'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer',
        'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
        'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', 'hardback_book',
        'harmonium', 'hat', 'hatbox', 'hatch', 'veil', 'headband',
        'headboard', 'headlight', 'headscarf', 'headset', 'headstall_(for_horses)', 'hearing_aid',
        'heart', 'heater', 'helicopter', 'helmet', 'heron', 'highchair',
        'hinge', 'hippopotamus', 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey',
        'fume_hood', 'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate',
        'hot_sauce', 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
        'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', 'ice_tea',
        'igniter', 'incense', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
        'jacket', 'jam', 'jean', 'jeep', 'jelly_bean', 'jersey',
        'jet_plane', 'jewelry', 'joystick', 'jumpsuit', 'kayak', 'keg',
        'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
        'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', 'knee_pad',
        'knife', 'knight_(chess_piece)', 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala',
        'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop',
        'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
        'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)', 'Lego',
        'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy', 'life_jacket',
        'lightbulb', 'lightning_rod', 'lime', 'limousine', 'linen_paper', 'lion',
        'lip_balm', 'lipstick', 'liquor', 'lizard', 'Loafer_(type_of_shoe)', 'log',
        'lollipop', 'lotion', 'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine',
        'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth', 'mandarin_orange',
        'manger', 'manhole', 'map', 'marker', 'martini', 'mascot',
        'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox',
        'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', 'melon',
        'microphone', 'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
        'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
        'monitor_(computer_equipment) computer_monitor',
        'monkey', 'motor', 'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
        'mound_(baseball)', 'mouse_(animal_rodent)', 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug',
        'mushroom', 'music_stool', 'musical_instrument', 'nailfile', 'nameplate', 'napkin',
        'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newsstand',
        'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook', 'notepad', 'nut',
        'nutcracker', 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
        'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'oregano', 'ostrich',
        'ottoman', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad',
        'paddle', 'padlock', 'paintbox', 'paintbrush', 'painting', 'pajamas',
        'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya',
        'paperclip', 'paper_plate', 'paper_towel', 'paperback_book', 'paperweight', 'parachute',
        'parakeet', 'parasail_(sports)', 'parchment', 'parka', 'parking_meter', 'parrot',
        'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', 'pastry', 'patty_(food)', 'pea_(food)',
        'peach', 'peanut_butter', 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard', 'pelican',
        'pen', 'pencil', 'pencil_box', 'pencil_sharpener', 'pendulum', 'penguin',
        'pennant', 'penny_(coin)', 'pepper', 'pepper_mill', 'perfume', 'persimmon',
        'baby', 'pet', 'petfood', 'pew_(church_bench)', 'phonebook', 'phonograph_record',
        'piano', 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank',
        'pillow', 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', 'pinwheel',
        'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', 'pitcher_(vessel_for_liquid)', 'pitchfork',
        'pizza', 'place_mat', 'plate', 'platter', 'playing_card', 'playpen',
        'pliers', 'plow_(farm_equipment)', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', 'pole',
        'police_van', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
        'portrait', 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
        'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
        'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding',
        'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet',
        'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car',
        'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft',
        'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', 'razorblade',
        'reamer_(juicer)', 'rearview_mirror', 'receipt', 'recliner', 'record_player', 'red_cabbage',
        'reflector', 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
        'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate', 'Rollerblade',
        'rolling_pin', 'root_beer', 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', 'plastic_bag',
        'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin', 'sail', 'salad',
        'salad_plate', 'salami', 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
        'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage',
        'sawhorse', 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', 'school_bus',
        'scissors', 'scoreboard', 'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
        'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', 'seashell',
        'seedling', 'serving_dish', 'sewing_machine', 'shaker', 'shampoo', 'shark',
        'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', 'shears',
        'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', 'shoe',
        'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel',
        'shower_head', 'shower_curtain', 'shredder_(for_paper)', 'sieve', 'signboard', 'silo',
        'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka',
        'ski_pole', 'skirt', 'sled', 'sleeping_bag', 'sling_(bandage)', 'slipper_(footwear)',
        'smoothie', 'snake', 'snowboard', 'snowman', 'snowmobile', 'soap',
        'soccer_ball', 'sock', 'soda_fountain', 'carbonated_water', 'sofa', 'softball',
        'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', 'sour_cream',
        'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear', 'spectacles',
        'spice_rack', 'spider', 'sponge', 'spoon', 'sportswear', 'spotlight',
        'squirrel', 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)', 'steak_(food)', 'steak_knife',
        'steamer_(kitchen_appliance)', 'steering_wheel', 'stencil', 'stepladder', 'step_stool', 'stereo_(sound_system)',
        'stew', 'stirrer', 'stirrup', 'stockings_(leg_wear)', 'stool', 'stop_sign',
        'brake_light', 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
        'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl',
        'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses', 'sunhat', 'sunscreen',
        'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband', 'sweater',
        'sweatshirt', 'sweet_potato', 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce',
        'table-tennis_table', 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco',
        'tag', 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)', 'tank_top_(clothing)',
        'tape_(sticky_cloth_or_paper)', 'tape_measure', 'tapestry', 'tarp', 'tartan', 'tassel',
        'tea_bag', 'teacup', 'teakettle', 'teapot', 'teddy_bear', 'telephone',
        'telephone_booth', 'telephone_pole', 'telephoto_lens', 'television_camera', 'television_set', 'tennis_ball',
        'tennis_racket', 'tequila', 'thermometer', 'thermos_bottle', 'thermostat', 'thimble',
        'thread', 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
        'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
        'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
        'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
        'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', 'trailer_truck',
        'train_(railroad_vehicle)', 'trampoline', 'tray', 'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
        'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk',
        'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip', 'turtle',
        'turtleneck_(clothing)', 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal',
        'urn', 'vacuum_cleaner', 'valve', 'vase', 'vending_machine', 'vent',
        'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
        'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', 'wall_clock',
        'wall_socket', 'wallet', 'walrus', 'wardrobe', 'wasabi', 'automatic_washer',
        'watch', 'water_bottle', 'water_cooler', 'water_faucet', 'water_filter', 'water_heater',
        'water_jug', 'water_gun', 'water_scooter', 'water_ski', 'water_tower', 'watering_can',
        'watermelon', 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit',
        'wheel', 'wheelchair', 'whipped_cream', 'whiskey', 'whistle', 'wick',
        'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock',
        'wine_bottle', 'wine_bucket', 'wineglass', 'wing_chair', 'blinder_(for_horses)', 'wok',
        'wolf', 'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet',
        'yacht', 'yak', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini')
    def load_annotations(self, ann_file):
        """Load annotation from lvis style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from LVIS api.
        """
        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVIS
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        self.coco = LVIS(ann_file)
        self.cat_ids = self.coco.get_cat_ids()
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            if info['file_name'].startswith('COCO'):
                # Convert from the COCO 2014 file naming convention of
                # COCO_[train/val/test]2014_000000000000.jpg to the 2017
                # naming convention of 000000000000.jpg
                # (LVIS v1 will fix this naming issue)
                info['filename'] = info['file_name'][-16:]
            else:
                info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos
    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=np.arange(0.5, 0.96, 0.05)):
        """Evaluation in LVIS protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of the output json files.
                If not specified, a temp file will be created. Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float]): IoU threshold used for evaluating
                recalls. If set to a list, the average recall of all IoUs will
                also be computed. Default: np.arange(0.5, 0.96, 0.05).

        Returns:
            dict[str, float]: LVIS style metrics.
        """
        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVISResults, LVISEval
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError('metric {} is not supported'.format(metric))

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)

        eval_results = OrderedDict()
        # get original api
        lvis_gt = self.coco
        for metric in metrics:
            msg = 'Evaluating {}...'.format(metric)
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(
                    results, proposal_nums, iou_thrs, logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results['AR@{}'.format(num)] = ar[i]
                    log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i]))
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            if metric not in result_files:
                raise KeyError('{} is not in results'.format(metric))
            try:
                lvis_dt = LVISResults(lvis_gt, result_files[metric])
            except IndexError:
                print_log(
                    'The testing results of the whole dataset is empty.',
                    logger=logger,
                    level=logging.ERROR)
                break

            iou_type = 'bbox' if metric == 'proposal' else metric
            lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type)
            lvis_eval.params.imgIds = self.img_ids
            if metric == 'proposal':
                lvis_eval.params.useCats = 0
                lvis_eval.params.maxDets = list(proposal_nums)
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                for k, v in lvis_eval.get_results().items():
                    if k.startswith('AR'):
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[k] = val
            else:
                lvis_eval.evaluate()
                lvis_eval.accumulate()
                lvis_eval.summarize()
                lvis_results = lvis_eval.get_results()
                if classwise:  # Compute per-category AP
                    # from https://github.com/facebookresearch/detectron2/
                    precisions = lvis_eval.eval['precision']
                    # precision: (iou, recall, cls, area range, max dets)
                    assert len(self.cat_ids) == precisions.shape[2]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        # max dets index -1: typically 100 per image
                        nm = self.coco.load_cats(catId)[0]
                        precision = precisions[:, :, idx, 0, -1]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                for k, v in lvis_results.items():
                    if k.startswith('AP'):
                        key = '{}_{}'.format(metric, k)
                        val = float('{:.3f}'.format(float(v)))
                        eval_results[key] = val
                ap_summary = ' '.join([
                    '{}:{:.3f}'.format(k, float(v))
                    for k, v in lvis_results.items() if k.startswith('AP')
                ])
                eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary
            lvis_eval.print_results()
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results
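
# Illustrative note, not part of the committed file: after inference, a call
# such as `dataset.evaluate(det_results, metric='bbox', classwise=True)`
# returns an OrderedDict with keys like 'bbox_AP' and 'bbox_AP50', plus the
# 'bbox_mAP_copypaste' summary string assembled above.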
LVISDataset = LVISV05Dataset
DATASETS.register_module(name='LVISDataset', module=LVISDataset)


@DATASETS.register_module()
class LVISV1Dataset(LVISDataset):

    CLASSES = (
        'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol', 'alligator',
        'almond', 'ambulance', 'amplifier', 'anklet', 'antenna', 'apple',
        'applesauce', 'apricot', 'apron', 'aquarium', 'arctic_(type_of_shoe)', 'armband',
        'armchair', 'armoire', 'armor', 'artichoke', 'trash_can', 'ashtray',
        'asparagus', 'atomizer', 'avocado', 'award', 'awning', 'ax',
        'baboon', 'baby_buggy', 'basketball_backboard', 'backpack', 'handbag', 'suitcase',
        'bagel', 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt',
        'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna',
        'banjo', 'banner', 'barbell', 'barge', 'barrel', 'barrette',
        'barrow', 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', 'baseball_glove',
        'basket', 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
        'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead',
        'bean_curd', 'beanbag', 'beanie', 'bear', 'bed', 'bedpan',
        'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can',
        'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench',
        'beret', 'bib', 'Bible', 'bicycle', 'visor', 'billboard',
        'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage',
        'birdhouse', 'birthday_cake', 'birthday_card', 'pirate_flag', 'black_sheep', 'blackberry',
        'blackboard', 'blanket', 'blazer', 'blender', 'blimp', 'blinker',
        'blouse', 'blueberry', 'gameboard', 'boat', 'bob', 'bobbin',
        'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet',
        'book', 'bookcase', 'booklet', 'bookmark', 'boom_microphone', 'boot',
        'bottle', 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie',
        'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove',
        'suspenders', 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread',
        'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach', 'broom',
        'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
        'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn',
        'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card',
        'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
        'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
        'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can',
        'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane', 'walking_cane',
        'canister', 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
        'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', 'car_battery',
        'identity_card', 'card', 'cardigan', 'cargo_ship', 'carnation', 'horse_carriage',
        'carrot', 'tote_bag', 'cart', 'carton', 'cash_register', 'casserole',
        'cassette', 'cast', 'cat', 'cauliflower', 'cayenne_(spice)', 'CD_player',
        'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', 'chalice',
        'chandelier', 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard',
        'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
        'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
        'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', 'cigar_box',
        'cigarette', 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
        'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak',
        'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster',
        'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)',
        'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin',
        'colander', 'coleslaw', 'coloring_material', 'combination_lock', 'pacifier', 'comic_book',
        'compass', 'computer_keyboard', 'condiment', 'cone', 'control', 'convertible_(automobile)',
        'sofa_bed', 'cooker', 'cookie', 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
        'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet', 'cornice',
        'cornmeal', 'corset', 'costume', 'cougar', 'coverall', 'cowbell',
        'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', 'crape', 'crate',
        'crayon', 'cream_pitcher', 'crescent_roll', 'crib', 'crock_pot', 'crossbar',
        'crouton', 'crow', 'crowbar', 'crown', 'crucifix', 'cruise_ship',
        'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube', 'cucumber',
        'cufflink', 'cup', 'trophy_cup', 'cupboard', 'cupcake', 'hair_curler',
        'curling_iron', 'curtain', 'cushion', 'cylinder', 'cymbal', 'dagger',
        'dalmatian', 'dartboard', 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss',
        'desk', 'detergent', 'diaper', 'diary', 'die', 'dinghy',
        'dining_table', 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel',
        'dishwasher', 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', 'dog',
        'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', 'domestic_ass',
        'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', 'drawer',
        'underdrawers', 'dress', 'dress_hat', 'dress_suit', 'dresser', 'drill',
        'drone', 'dropper', 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
        'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle',
        'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel',
        'egg', 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair',
        'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot',
        'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret',
        'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine', 'file_cabinet',
        'file_(tool)', 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
        'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod',
        'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash',
        'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', 'flower_arrangement', 'flute_glass',
        'foal', 'folding_chair', 'food_processor', 'football_(American)', 'football_helmet', 'footstool',
        'fork', 'forklift', 'freight_car', 'French_toast', 'freshener', 'frisbee',
        'frog', 'fruit_juice', 'frying_pan', 'fudge', 'funnel', 'futon',
        'gag', 'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle',
        'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator',
        'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', 'glass_(drink_container)',
        'globe', 'glove', 'goat', 'goggles', 'goldfish', 'golf_club',
        'golfcart', 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape',
        'grater', 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
        'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull',
        'gun', 'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham',
        'hamburger', 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer',
        'hand_glass', 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle',
        'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil',
        'headband', 'headboard', 'headlight', 'headscarf', 'headset', 'headstall_(for_horses)',
        'heart', 'heater', 'helicopter', 'helmet', 'heron', 'highchair',
        'hinge', 'hippopotamus', 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey',
        'fume_hood', 'hook', 'hookah', 'hornet', 'horse', 'hose',
        'hot-air_balloon', 'hotplate', 'hot_sauce', 'hourglass', 'houseboat', 'hummingbird',
        'hummus', 'polar_bear', 'icecream', 'popsicle', 'ice_maker', 'ice_pack',
        'ice_skate', 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
        'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean',
        'jersey', 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit',
        'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
        'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
        'kiwi_fruit', 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
        'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
        'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
        'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
        'Lego', 'legume', 'lemon', 'lemonade', 'lettuce', 'license_plate',
        'life_buoy', 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
        'lion', 'lip_balm', 'liquor', 'lizard', 'log', 'lollipop',
        'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine', 'magnet', 'mail_slot',
        'mailbox_(at_home)', 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange',
        'manger', 'manhole', 'map', 'marker', 'martini', 'mascot',
        'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox',
        'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', 'melon',
        'microphone', 'microscope', 'microwave_oven', 'milestone', 'milk', 'milk_can',
        'milkshake', 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)',
        'money', 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', 'motor_scooter', 'motor_vehicle',
        'motorcycle', 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug',
        'mushroom', 'music_stool', 'musical_instrument', 'nailfile', 'napkin', 'neckerchief',
        'necklace', 'necktie', 'needle', 'nest', 'newspaper', 'newsstand',
        'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook', 'notepad', 'nut',
        'nutcracker', 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
        'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman',
        'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad',
        'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette',
        'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya', 'paper_plate',
        'paper_towel', 'paperback_book', 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
        'parasol', 'parchment', 'parka', 'parking_meter', 'parrot', 'passenger_car_(part_of_a_train)',
        'passenger_ship', 'passport', 'pastry', 'patty_(food)', 'pea_(food)', 'peach',
        'peanut_butter', 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', 'pegboard', 'pelican',
        'pen', 'pencil', 'pencil_box', 'pencil_sharpener', 'pendulum', 'penguin',
        'pennant', 'penny_(coin)', 'pepper', 'pepper_mill', 'perfume', 'persimmon',
        'person', 'pet', 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
        'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
        'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', 'pinwheel', 'tobacco_pipe',
        'pipe', 'pistol', 'pita_(bread)', 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza',
        'place_mat', 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
        'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', 'pole', 'polo_shirt',
        'poncho', 'pony', 'pool_table', 'pop_(soda)', 'postbox_(public)', 'postcard',
        'poster', 'pot', 'flowerpot', 'potato', 'potholder', 'pottery',
        'pouch', 'power_shovel', 'prawn', 'pretzel', 'printer', 'projectile_(weapon)',
        'projector', 'propeller', 'prune', 'pudding', 'puffer_(fish)', 'puffin',
        'pug-dog', 'pumpkin', 'puncher', 'puppet', 'puppy', 'quesadilla',
        'quiche', 'quilt', 'rabbit', 'race_car', 'racket', 'radar',
        'radiator', 'radio_receiver', 'radish', 'raft', 'rag_doll', 'raincoat',
        'ram_(animal)', 'raspberry', 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror',
        'receipt', 'recliner', 'record_player', 'reflector', 'remote_control', 'rhinoceros',
        'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', 'robe',
        'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', 'rolling_pin', 'root_beer',
        'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket',
        'saddlebag', 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
        'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', 'sandwich',
        'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', 'saxophone',
        'scale_(measuring_instrument)', 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
        'scraper', 'screwdriver', 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse',
        'seaplane', 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
        'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', 'shears',
        'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', 'shoe',
        'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel',
        'shower_head', 'shower_cap', 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo',
        'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka',
        'ski_pole', 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
        'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', 'snowmobile',
        'soap', 'soccer_ball', 'sock', 'sofa', 'softball', 'solar_array',
        'sombrero', 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk',
        'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear', 'spectacles', 'spice_rack',
        'spider', 'crawfish', 'sponge', 'spoon', 'sportswear', 'spotlight',
        'squid_(food)', 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)',
        'steak_(food)', 'steak_knife', 'steering_wheel', 'stepladder', 'step_stool', 'stereo_(sound_system)',
        'stew', 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light',
        'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign',
        'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl', 'sugarcane_(plant)',
        'suit_(clothing)', 'sunflower', 'sunglasses', 'sunhat', 'surfboard', 'sushi',
        'mop', 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
        'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
        'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
        'tambourine', 'army_tank', 'tank_(storage_vessel)', 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
        'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
        'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', 'telephone_pole',
        'telephoto_lens', 'television_camera', 'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
        'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', 'thumbtack',
        'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', 'tinsel',
        'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', 'toilet', 'toilet_tissue',
        'tomato', 'tongs', 'toolbox', 'toothbrush', 'toothpaste', 'toothpick',
        'cover', 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
        'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
        'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod', 'trousers',
        'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban', 'turkey_(food)',
        'turnip', 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella', 'underwear',
        'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'vase', 'vending_machine',
        'vent', 'vest', 'videotape', 'vinegar', 'violin', 'vodka',
        'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel',
        'walking_stick', 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
        'washbasin', 'automatic_washer', 'watch', 'water_bottle', 'water_cooler', 'water_faucet',
        'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski', 'water_tower',
        'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring',
        'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', 'whistle', 'wig',
        'wind_chime', 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock', 'wine_bottle',
        'wine_bucket', 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
        'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
        'yoke_(animal_equipment)', 'zebra', 'zucchini')
    def load_annotations(self, ann_file):
        try:
            import lvis
            assert lvis.__version__ >= '10.5.3'
            from lvis import LVIS
        except AssertionError:
            raise AssertionError('Incompatible version of lvis is installed. '
                                 'Run pip uninstall lvis first. Then run pip '
                                 'install mmlvis to install open-mmlab forked '
                                 'lvis. ')
        except ImportError:
            raise ImportError('Package lvis is not installed. Please run pip '
                              'install mmlvis to install open-mmlab forked '
                              'lvis.')
        self.coco = LVIS(ann_file)
        self.cat_ids = self.coco.get_cat_ids()
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            # coco_url is used in LVISv1 instead of file_name
            # e.g. http://images.cocodataset.org/train2017/000000391895.jpg
            # the train/val split is specified in the url
            info['filename'] = info['coco_url'].replace(
                'http://images.cocodataset.org/', '')
            data_infos.append(info)
        return data_infos
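
A minimal usage sketch of how the datasets registered above are typically built through mmdet's registry (this is not part of the commit; `build_dataset` is mmdet's standard helper, and the paths are hypothetical placeholders):

from mmdet.datasets import build_dataset

cfg = dict(
    type='LVISV1Dataset',
    ann_file='data/lvis_v1/annotations/lvis_v1_train.json',  # hypothetical path
    img_prefix='data/lvis_v1/',                              # hypothetical path
    pipeline=[])
dataset = build_dataset(cfg)
print(len(dataset.CLASSES))  # 1203 categories in the LVIS v1 tuple above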
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/__init__.py  (new file, 0 → 100644)
from .auto_augment import (AutoAugment, BrightnessTransform, ColorTransform,
                           ContrastTransform, EqualizeTransform, Rotate, Shear,
                           Translate)
from .compose import Compose
from .formating import (Collect, DefaultFormatBundle, ImageToTensor,
                        ToDataContainer, ToTensor, Transpose, to_tensor)
from .instaboost import InstaBoost
from .loading import (LoadAnnotations, LoadImageFromFile, LoadImageFromWebcam,
                      LoadMultiChannelImageFromFiles, LoadProposals)
from .test_time_aug import MultiScaleFlipAug
from .transforms import (Albu, CutOut, Expand, MinIoURandomCrop, Normalize,
                         Pad, PhotoMetricDistortion, RandomCenterCropPad,
                         RandomCrop, RandomFlip, Resize, SegRescale)

__all__ = [
    'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
    'Transpose', 'Collect', 'DefaultFormatBundle', 'LoadAnnotations',
    'LoadImageFromFile', 'LoadImageFromWebcam',
    'LoadMultiChannelImageFromFiles', 'LoadProposals', 'MultiScaleFlipAug',
    'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale',
    'MinIoURandomCrop', 'Expand', 'PhotoMetricDistortion', 'Albu',
    'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear',
    'Rotate', 'ColorTransform', 'EqualizeTransform', 'BrightnessTransform',
    'ContrastTransform', 'Translate'
]
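
A brief sketch of how these exported transforms are usually chained in an mmdet pipeline config (the values shown are common illustrative defaults, not taken from this commit). Each dict's `type` must match a name registered in PIPELINES, and `Compose` instantiates and runs them in order:

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]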
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/auto_augment.py  (new file, 0 → 100644)
import copy

import cv2
import mmcv
import numpy as np

from ..builder import PIPELINES
from .compose import Compose

_MAX_LEVEL = 10


def level_to_value(level, max_value):
    """Map from level to values based on max_value."""
    return (level / _MAX_LEVEL) * max_value


def enhance_level_to_value(level, a=1.8, b=0.1):
    """Map from level to values."""
    return (level / _MAX_LEVEL) * a + b


def random_negative(value, random_negative_prob):
    """Randomly negate value based on random_negative_prob."""
    return -value if np.random.rand() < random_negative_prob else value


def bbox2fields():
    """The key correspondence from bboxes to labels, masks and
    segmentations."""
    bbox2label = {
        'gt_bboxes': 'gt_labels',
        'gt_bboxes_ignore': 'gt_labels_ignore'
    }
    bbox2mask = {
        'gt_bboxes': 'gt_masks',
        'gt_bboxes_ignore': 'gt_masks_ignore'
    }
    bbox2seg = {
        'gt_bboxes': 'gt_semantic_seg',
    }
    return bbox2label, bbox2mask, bbox2seg
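

# Illustrative examples of the helpers above (added for clarity, not part of
# the committed file; the numbers follow directly from the formulas with
# _MAX_LEVEL = 10):
#   level_to_value(5, max_value=0.3)   -> (5 / 10) * 0.3        = 0.15
#   enhance_level_to_value(10)         -> (10 / 10) * 1.8 + 0.1 = 1.9
#   random_negative(0.15, 0.0)         -> 0.15 (never negated when prob is 0)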


@PIPELINES.register_module()
class AutoAugment(object):
    """Auto augmentation.

    This data augmentation is proposed in `Learning Data Augmentation
    Strategies for Object Detection <https://arxiv.org/pdf/1906.11172>`_.

    TODO: Implement 'Shear', 'Sharpness' and 'Rotate' transforms

    Args:
        policies (list[list[dict]]): The policies of auto augmentation. Each
            policy in ``policies`` is a specific augmentation policy, and is
            composed of several augmentations (dict). When AutoAugment is
            called, a random policy in ``policies`` will be selected to
            augment images.

    Examples:
        >>> replace = (104, 116, 124)
        >>> policies = [
        >>>     [
        >>>         dict(type='Sharpness', prob=0.0, level=8),
        >>>         dict(
        >>>             type='Shear',
        >>>             prob=0.4,
        >>>             level=0,
        >>>             replace=replace,
        >>>             axis='x')
        >>>     ],
        >>>     [
        >>>         dict(
        >>>             type='Rotate',
        >>>             prob=0.6,
        >>>             level=10,
        >>>             replace=replace),
        >>>         dict(type='Color', prob=1.0, level=6)
        >>>     ]
        >>> ]
        >>> augmentation = AutoAugment(policies)
        >>> img = np.ones((100, 100, 3))
        >>> gt_bboxes = np.ones((10, 4))
        >>> results = dict(img=img, gt_bboxes=gt_bboxes)
        >>> results = augmentation(results)
    """

    def __init__(self, policies):
        assert isinstance(policies, list) and len(policies) > 0, \
            'Policies must be a non-empty list.'
        for policy in policies:
            assert isinstance(policy, list) and len(policy) > 0, \
                'Each policy in policies must be a non-empty list.'
            for augment in policy:
                assert isinstance(augment, dict) and 'type' in augment, \
                    'Each specific augmentation must be a dict with key' \
                    ' "type".'

        self.policies = copy.deepcopy(policies)
        self.transforms = [Compose(policy) for policy in self.policies]

    def __call__(self, results):
        transform = np.random.choice(self.transforms)
        return transform(results)

    def __repr__(self):
        return f'{self.__class__.__name__}(policies={self.policies})'
@
PIPELINES
.
register_module
()
class
Shear
(
object
):
"""Apply Shear Transformation to image (and its corresponding bbox, mask,
segmentation).
Args:
level (int | float): The level should be in range [0,_MAX_LEVEL].
img_fill_val (int | float | tuple): The filled values for image border.
If float, the same fill value will be used for all the three
channels of image. If tuple, the should be 3 elements.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Default 255.
prob (float): The probability for performing Shear and should be in
range [0, 1].
direction (str): The direction for shear, either "horizontal"
or "vertical".
max_shear_magnitude (float): The maximum magnitude for Shear
transformation.
random_negative_prob (float): The probability that turns the
offset negative. Should be in range [0,1]
interpolation (str): Same as in :func:`mmcv.imshear`.
"""
def
__init__
(
self
,
level
,
img_fill_val
=
128
,
seg_ignore_label
=
255
,
prob
=
0.5
,
direction
=
'horizontal'
,
max_shear_magnitude
=
0.3
,
random_negative_prob
=
0.5
,
interpolation
=
'bilinear'
):
assert
isinstance
(
level
,
(
int
,
float
)),
'The level must be type '
\
f
'int or float, got
{
type
(
level
)
}
.'
assert
0
<=
level
<=
_MAX_LEVEL
,
'The level should be in range '
\
f
'[0,
{
_MAX_LEVEL
}
], got
{
level
}
.'
if
isinstance
(
img_fill_val
,
(
float
,
int
)):
img_fill_val
=
tuple
([
float
(
img_fill_val
)]
*
3
)
elif
isinstance
(
img_fill_val
,
tuple
):
assert
len
(
img_fill_val
)
==
3
,
'img_fill_val as tuple must '
\
f
'have 3 elements. got
{
len
(
img_fill_val
)
}
.'
img_fill_val
=
tuple
([
float
(
val
)
for
val
in
img_fill_val
])
else
:
raise
ValueError
(
'img_fill_val must be float or tuple with 3 elements.'
)
assert
np
.
all
([
0
<=
val
<=
255
for
val
in
img_fill_val
]),
'all '
\
'elements of img_fill_val should between range [0,255].'
\
f
'got
{
img_fill_val
}
.'
assert
0
<=
prob
<=
1.0
,
'The probability of shear should be in '
\
f
'range [0,1]. got
{
prob
}
.'
assert
direction
in
(
'horizontal'
,
'vertical'
),
'direction must '
\
f
'in be either "horizontal" or "vertical". got
{
direction
}
.'
assert
isinstance
(
max_shear_magnitude
,
float
),
'max_shear_magnitude '
\
f
'should be type float. got
{
type
(
max_shear_magnitude
)
}
.'
assert
0.
<=
max_shear_magnitude
<=
1.
,
'Defaultly '
\
'max_shear_magnitude should be in range [0,1]. '
\
f
'got
{
max_shear_magnitude
}
.'
self
.
level
=
level
self
.
magnitude
=
level_to_value
(
level
,
max_shear_magnitude
)
self
.
img_fill_val
=
img_fill_val
self
.
seg_ignore_label
=
seg_ignore_label
self
.
prob
=
prob
self
.
direction
=
direction
self
.
max_shear_magnitude
=
max_shear_magnitude
self
.
random_negative_prob
=
random_negative_prob
self
.
interpolation
=
interpolation
def
_shear_img
(
self
,
results
,
magnitude
,
direction
=
'horizontal'
,
interpolation
=
'bilinear'
):
"""Shear the image.
Args:
results (dict): Result dict from loading pipeline.
magnitude (int | float): The magnitude used for shear.
direction (str): The direction for shear, either "horizontal"
or "vertical".
interpolation (str): Same as in :func:`mmcv.imshear`.
"""
        for key in results.get('img_fields', ['img']):
            img = results[key]
            img_sheared = mmcv.imshear(
                img,
                magnitude,
                direction,
                border_value=self.img_fill_val,
                interpolation=interpolation)
            results[key] = img_sheared.astype(img.dtype)

    def _shear_bboxes(self, results, magnitude):
        """Shear the bboxes."""
        h, w, c = results['img_shape']
        if self.direction == 'horizontal':
            shear_matrix = np.stack([[1, magnitude],
                                     [0, 1]]).astype(np.float32)  # [2, 2]
        else:
            shear_matrix = np.stack([[1, 0],
                                     [magnitude, 1]]).astype(np.float32)
        for key in results.get('bbox_fields', []):
            min_x, min_y, max_x, max_y = np.split(
                results[key], results[key].shape[-1], axis=-1)
            coordinates = np.stack([[min_x, min_y], [max_x, min_y],
                                    [min_x, max_y],
                                    [max_x, max_y]])  # [4, 2, nb_box, 1]
            coordinates = coordinates[..., 0].transpose(
                (2, 1, 0)).astype(np.float32)  # [nb_box, 2, 4]
            new_coords = np.matmul(shear_matrix[None, :, :],
                                   coordinates)  # [nb_box, 2, 4]
            min_x = np.min(new_coords[:, 0, :], axis=-1)
            min_y = np.min(new_coords[:, 1, :], axis=-1)
            max_x = np.max(new_coords[:, 0, :], axis=-1)
            max_y = np.max(new_coords[:, 1, :], axis=-1)
            min_x = np.clip(min_x, a_min=0, a_max=w)
            min_y = np.clip(min_y, a_min=0, a_max=h)
            max_x = np.clip(max_x, a_min=min_x, a_max=w)
            max_y = np.clip(max_y, a_min=min_y, a_max=h)
            results[key] = np.stack([min_x, min_y, max_x, max_y],
                                    axis=-1).astype(results[key].dtype)

    def _shear_masks(self,
                     results,
                     magnitude,
                     direction='horizontal',
                     fill_val=0,
                     interpolation='bilinear'):
        """Shear the masks."""
        h, w, c = results['img_shape']
        for key in results.get('mask_fields', []):
            masks = results[key]
            results[key] = masks.shear((h, w),
                                       magnitude,
                                       direction,
                                       border_value=fill_val,
                                       interpolation=interpolation)

    def _shear_seg(self,
                   results,
                   magnitude,
                   direction='horizontal',
                   fill_val=255,
                   interpolation='bilinear'):
        """Shear the segmentation maps."""
        for key in results.get('seg_fields', []):
            seg = results[key]
            results[key] = mmcv.imshear(
                seg,
                magnitude,
                direction,
                border_value=fill_val,
                interpolation=interpolation).astype(seg.dtype)

    def _filter_invalid(self, results, min_bbox_size=0):
        """Filter bboxes and corresponding masks too small after shear
        augmentation."""
        bbox2label, bbox2mask, _ = bbox2fields()
        for key in results.get('bbox_fields', []):
            bbox_w = results[key][:, 2] - results[key][:, 0]
            bbox_h = results[key][:, 3] - results[key][:, 1]
            valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
            valid_inds = np.nonzero(valid_inds)[0]
            results[key] = results[key][valid_inds]
            # label fields. e.g. gt_labels and gt_labels_ignore
            label_key = bbox2label.get(key)
            if label_key in results:
                results[label_key] = results[label_key][valid_inds]
            # mask fields, e.g. gt_masks and gt_masks_ignore
            mask_key = bbox2mask.get(key)
            if mask_key in results:
                results[mask_key] = results[mask_key][valid_inds]

    def __call__(self, results):
        """Call function to shear images, bounding boxes, masks and semantic
        segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Sheared results.
        """
        if np.random.rand() > self.prob:
            return results
        magnitude = random_negative(self.magnitude, self.random_negative_prob)
        self._shear_img(results, magnitude, self.direction, self.interpolation)
        self._shear_bboxes(results, magnitude)
        # fill_val set to 0 for background of mask.
        self._shear_masks(
            results,
            magnitude,
            self.direction,
            fill_val=0,
            interpolation=self.interpolation)
        self._shear_seg(
            results,
            magnitude,
            self.direction,
            fill_val=self.seg_ignore_label,
            interpolation=self.interpolation)
        self._filter_invalid(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(level={self.level}, '
        repr_str += f'img_fill_val={self.img_fill_val}, '
        repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
        repr_str += f'prob={self.prob}, '
        repr_str += f'direction={self.direction}, '
        repr_str += f'max_shear_magnitude={self.max_shear_magnitude}, '
        repr_str += f'random_negative_prob={self.random_negative_prob}, '
        repr_str += f'interpolation={self.interpolation})'
        return repr_str
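
# --- Illustrative sketch (not part of the original commit) ---
# A minimal numpy-only demo of the corner-point trick used by
# ``_shear_bboxes`` above: apply the 2x2 shear matrix to the four corners of
# one box, then take per-axis min/max to recover an axis-aligned box. The box
# coordinates and magnitude below are made-up values for illustration.
def _demo_shear_one_bbox():
    magnitude = 0.3  # horizontal shear factor
    shear_matrix = np.array([[1, magnitude], [0, 1]], dtype=np.float32)
    x1, y1, x2, y2 = 10., 20., 50., 60.
    corners = np.array([[x1, y1], [x2, y1], [x1, y2], [x2, y2]],
                       dtype=np.float32).T           # shape [2, 4]
    new_corners = shear_matrix @ corners             # shape [2, 4]
    return np.array([new_corners[0].min(), new_corners[1].min(),
                     new_corners[0].max(), new_corners[1].max()])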
@PIPELINES.register_module()
class Rotate(object):
    """Apply Rotate Transformation to image (and its corresponding bbox, mask,
    segmentation).

    Args:
        level (int | float): The level should be in range (0,_MAX_LEVEL].
        scale (int | float): Isotropic scale factor. Same in
            ``mmcv.imrotate``.
        center (int | float | tuple[float]): Center point (w, h) of the
            rotation in the source image. If None, the center of the
            image will be used. Same in ``mmcv.imrotate``.
        img_fill_val (int | float | tuple): The fill value for image border.
            If float, the same value will be used for all the three
            channels of image. If tuple, it should have 3 elements (e.g.
            equals the number of channels for image).
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Default 255.
        prob (float): The probability of performing the transformation;
            should be in range [0, 1].
        max_rotate_angle (int | float): The maximum angle for the rotate
            transformation.
        random_negative_prob (float): The probability that turns the
            offset negative.
    """

    def __init__(self,
                 level,
                 scale=1,
                 center=None,
                 img_fill_val=128,
                 seg_ignore_label=255,
                 prob=0.5,
                 max_rotate_angle=30,
                 random_negative_prob=0.5):
        assert isinstance(level, (int, float)), \
            f'The level must be type int or float. got {type(level)}.'
        assert 0 <= level <= _MAX_LEVEL, \
            f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.'
        assert isinstance(scale, (int, float)), \
            f'The scale must be type int or float. got type {type(scale)}.'
        if isinstance(center, (int, float)):
            center = (center, center)
        elif isinstance(center, tuple):
            assert len(center) == 2, 'center with type tuple must have ' \
                f'2 elements. got {len(center)} elements.'
        else:
            assert center is None, 'center must be None or type int, ' \
                f'float or tuple, got type {type(center)}.'
        if isinstance(img_fill_val, (float, int)):
            img_fill_val = tuple([float(img_fill_val)] * 3)
        elif isinstance(img_fill_val, tuple):
            assert len(img_fill_val) == 3, 'img_fill_val as tuple must ' \
                f'have 3 elements. got {len(img_fill_val)}.'
            img_fill_val = tuple([float(val) for val in img_fill_val])
        else:
            raise ValueError(
                'img_fill_val must be float or tuple with 3 elements.')
        assert np.all([0 <= val <= 255 for val in img_fill_val]), \
            'all elements of img_fill_val should between range [0,255]. ' \
            f'got {img_fill_val}.'
        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
            f'got {prob}.'
        assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle ' \
            f'should be type int or float. got type {type(max_rotate_angle)}.'
        self.level = level
        self.scale = scale
        # Rotation angle in degrees. Positive values mean
        # clockwise rotation.
        self.angle = level_to_value(level, max_rotate_angle)
        self.center = center
        self.img_fill_val = img_fill_val
        self.seg_ignore_label = seg_ignore_label
        self.prob = prob
        self.max_rotate_angle = max_rotate_angle
        self.random_negative_prob = random_negative_prob

    def _rotate_img(self, results, angle, center=None, scale=1.0):
        """Rotate the image.

        Args:
            results (dict): Result dict from loading pipeline.
            angle (float): Rotation angle in degrees, positive values
                mean clockwise rotation. Same in ``mmcv.imrotate``.
            center (tuple[float], optional): Center point (w, h) of the
                rotation. Same in ``mmcv.imrotate``.
            scale (int | float): Isotropic scale factor. Same in
                ``mmcv.imrotate``.
        """
        for key in results.get('img_fields', ['img']):
            img = results[key].copy()
            img_rotated = mmcv.imrotate(
                img, angle, center, scale, border_value=self.img_fill_val)
            results[key] = img_rotated.astype(img.dtype)

    def _rotate_bboxes(self, results, rotate_matrix):
        """Rotate the bboxes."""
        h, w, c = results['img_shape']
        for key in results.get('bbox_fields', []):
            min_x, min_y, max_x, max_y = np.split(
                results[key], results[key].shape[-1], axis=-1)
            coordinates = np.stack([[min_x, min_y], [max_x, min_y],
                                    [min_x, max_y],
                                    [max_x, max_y]])  # [4, 2, nb_bbox, 1]
            # pad 1 to convert from format [x, y] to homogeneous
            # coordinates format [x, y, 1]
            coordinates = np.concatenate(
                (coordinates,
                 np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)),
                axis=1)  # [4, 3, nb_bbox, 1]
            coordinates = coordinates.transpose(
                (2, 0, 1, 3))  # [nb_bbox, 4, 3, 1]
            rotated_coords = np.matmul(rotate_matrix,
                                       coordinates)  # [nb_bbox, 4, 2, 1]
            rotated_coords = rotated_coords[..., 0]  # [nb_bbox, 4, 2]
            min_x, min_y = np.min(
                rotated_coords[:, :, 0], axis=1), np.min(
                    rotated_coords[:, :, 1], axis=1)
            max_x, max_y = np.max(
                rotated_coords[:, :, 0], axis=1), np.max(
                    rotated_coords[:, :, 1], axis=1)
            min_x, min_y = np.clip(
                min_x, a_min=0, a_max=w), np.clip(
                    min_y, a_min=0, a_max=h)
            max_x, max_y = np.clip(
                max_x, a_min=min_x, a_max=w), np.clip(
                    max_y, a_min=min_y, a_max=h)
            results[key] = np.stack([min_x, min_y, max_x, max_y],
                                    axis=-1).astype(results[key].dtype)

    def _rotate_masks(self,
                      results,
                      angle,
                      center=None,
                      scale=1.0,
                      fill_val=0):
        """Rotate the masks."""
        h, w, c = results['img_shape']
        for key in results.get('mask_fields', []):
            masks = results[key]
            results[key] = masks.rotate((h, w), angle, center, scale, fill_val)

    def _rotate_seg(self,
                    results,
                    angle,
                    center=None,
                    scale=1.0,
                    fill_val=255):
        """Rotate the segmentation map."""
        for key in results.get('seg_fields', []):
            seg = results[key].copy()
            results[key] = mmcv.imrotate(
                seg, angle, center, scale,
                border_value=fill_val).astype(seg.dtype)

    def _filter_invalid(self, results, min_bbox_size=0):
        """Filter bboxes and corresponding masks too small after rotate
        augmentation."""
        bbox2label, bbox2mask, _ = bbox2fields()
        for key in results.get('bbox_fields', []):
            bbox_w = results[key][:, 2] - results[key][:, 0]
            bbox_h = results[key][:, 3] - results[key][:, 1]
            valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
            valid_inds = np.nonzero(valid_inds)[0]
            results[key] = results[key][valid_inds]
            # label fields. e.g. gt_labels and gt_labels_ignore
            label_key = bbox2label.get(key)
            if label_key in results:
                results[label_key] = results[label_key][valid_inds]
            # mask fields, e.g. gt_masks and gt_masks_ignore
            mask_key = bbox2mask.get(key)
            if mask_key in results:
                results[mask_key] = results[mask_key][valid_inds]

    def __call__(self, results):
        """Call function to rotate images, bounding boxes, masks and semantic
        segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results.
        """
        if np.random.rand() > self.prob:
            return results
        h, w = results['img'].shape[:2]
        center = self.center
        if center is None:
            center = ((w - 1) * 0.5, (h - 1) * 0.5)
        angle = random_negative(self.angle, self.random_negative_prob)
        self._rotate_img(results, angle, center, self.scale)
        rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale)
        self._rotate_bboxes(results, rotate_matrix)
        self._rotate_masks(results, angle, center, self.scale, fill_val=0)
        self._rotate_seg(
            results, angle, center, self.scale, fill_val=self.seg_ignore_label)
        self._filter_invalid(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(level={self.level}, '
        repr_str += f'scale={self.scale}, '
        repr_str += f'center={self.center}, '
        repr_str += f'img_fill_val={self.img_fill_val}, '
        repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
        repr_str += f'prob={self.prob}, '
        repr_str += f'max_rotate_angle={self.max_rotate_angle}, '
        repr_str += f'random_negative_prob={self.random_negative_prob})'
        return repr_str
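
# --- Illustrative sketch (not part of the original commit) ---
# How ``_rotate_bboxes`` uses the 2x3 affine matrix from
# ``cv2.getRotationMatrix2D``: pad each corner to a homogeneous [x, y, 1]
# coordinate, multiply, then take per-axis min/max. The center, angle and box
# values are made up for illustration; note ``__call__`` above negates the
# angle before building the matrix.
def _demo_rotate_one_bbox():
    center, angle, scale = (64, 64), 30, 1.0
    rotate_matrix = cv2.getRotationMatrix2D(center, -angle, scale)  # [2, 3]
    x1, y1, x2, y2 = 10., 20., 50., 60.
    corners = np.array([[x1, y1, 1], [x2, y1, 1],
                        [x1, y2, 1], [x2, y2, 1]]).T                # [3, 4]
    new_corners = rotate_matrix @ corners                           # [2, 4]
    return np.array([new_corners[0].min(), new_corners[1].min(),
                     new_corners[0].max(), new_corners[1].max()])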
@PIPELINES.register_module()
class Translate(object):
    """Translate the images, bboxes, masks and segmentation maps horizontally
    or vertically.

    Args:
        level (int | float): The level for Translate and should be in
            range [0,_MAX_LEVEL].
        prob (float): The probability for performing translation and
            should be in range [0, 1].
        img_fill_val (int | float | tuple): The filled value for image
            border. If float, the same fill value will be used for all
            the three channels of image. If tuple, it should have 3
            elements (e.g. equals the number of channels for image).
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Default 255.
        direction (str): The translate direction, either "horizontal"
            or "vertical".
        max_translate_offset (int | float): The maximum pixel's offset for
            Translate.
        random_negative_prob (float): The probability that turns the
            offset negative.
        min_size (int | float): The minimum pixel for filtering
            invalid bboxes after the translation.
    """

    def __init__(self,
                 level,
                 prob=0.5,
                 img_fill_val=128,
                 seg_ignore_label=255,
                 direction='horizontal',
                 max_translate_offset=250.,
                 random_negative_prob=0.5,
                 min_size=0):
        assert isinstance(level, (int, float)), \
            'The level must be type int or float.'
        assert 0 <= level <= _MAX_LEVEL, \
            'The level used for calculating Translate\'s offset should be ' \
            'in range [0,_MAX_LEVEL]'
        assert 0 <= prob <= 1.0, \
            'The probability of translation should be in range [0, 1].'
        if isinstance(img_fill_val, (float, int)):
            img_fill_val = tuple([float(img_fill_val)] * 3)
        elif isinstance(img_fill_val, tuple):
            assert len(img_fill_val) == 3, \
                'img_fill_val as tuple must have 3 elements.'
            img_fill_val = tuple([float(val) for val in img_fill_val])
        else:
            raise ValueError('img_fill_val must be type float or tuple.')
        assert np.all([0 <= val <= 255 for val in img_fill_val]), \
            'all elements of img_fill_val should between range [0,255].'
        assert direction in ('horizontal', 'vertical'), \
            'direction should be "horizontal" or "vertical".'
        assert isinstance(max_translate_offset, (int, float)), \
            'The max_translate_offset must be type int or float.'
        # the offset used for translation
        self.offset = int(level_to_value(level, max_translate_offset))
        self.level = level
        self.prob = prob
        self.img_fill_val = img_fill_val
        self.seg_ignore_label = seg_ignore_label
        self.direction = direction
        self.max_translate_offset = max_translate_offset
        self.random_negative_prob = random_negative_prob
        self.min_size = min_size

    def _translate_img(self, results, offset, direction='horizontal'):
        """Translate the image.

        Args:
            results (dict): Result dict from loading pipeline.
            offset (int | float): The offset for translate.
            direction (str): The translate direction, either "horizontal"
                or "vertical".
        """
        for key in results.get('img_fields', ['img']):
            img = results[key].copy()
            results[key] = mmcv.imtranslate(
                img, offset, direction, self.img_fill_val).astype(img.dtype)

    def _translate_bboxes(self, results, offset):
        """Shift bboxes horizontally or vertically, according to offset."""
        h, w, c = results['img_shape']
        for key in results.get('bbox_fields', []):
            min_x, min_y, max_x, max_y = np.split(
                results[key], results[key].shape[-1], axis=-1)
            if self.direction == 'horizontal':
                min_x = np.maximum(0, min_x + offset)
                max_x = np.minimum(w, max_x + offset)
            elif self.direction == 'vertical':
                min_y = np.maximum(0, min_y + offset)
                max_y = np.minimum(h, max_y + offset)
            # the boxes translated outside of the image will be filtered
            # along with the corresponding masks, by invoking
            # ``_filter_invalid``.
            results[key] = np.concatenate([min_x, min_y, max_x, max_y],
                                          axis=-1)

    def _translate_masks(self,
                         results,
                         offset,
                         direction='horizontal',
                         fill_val=0):
        """Translate masks horizontally or vertically."""
        h, w, c = results['img_shape']
        for key in results.get('mask_fields', []):
            masks = results[key]
            results[key] = masks.translate((h, w), offset, direction, fill_val)

    def _translate_seg(self,
                       results,
                       offset,
                       direction='horizontal',
                       fill_val=255):
        """Translate segmentation maps horizontally or vertically."""
        for key in results.get('seg_fields', []):
            seg = results[key].copy()
            results[key] = mmcv.imtranslate(seg, offset, direction,
                                            fill_val).astype(seg.dtype)

    def _filter_invalid(self, results, min_size=0):
        """Filter bboxes and masks too small or translated out of image."""
        bbox2label, bbox2mask, _ = bbox2fields()
        for key in results.get('bbox_fields', []):
            bbox_w = results[key][:, 2] - results[key][:, 0]
            bbox_h = results[key][:, 3] - results[key][:, 1]
            valid_inds = (bbox_w > min_size) & (bbox_h > min_size)
            valid_inds = np.nonzero(valid_inds)[0]
            results[key] = results[key][valid_inds]
            # label fields. e.g. gt_labels and gt_labels_ignore
            label_key = bbox2label.get(key)
            if label_key in results:
                results[label_key] = results[label_key][valid_inds]
            # mask fields, e.g. gt_masks and gt_masks_ignore
            mask_key = bbox2mask.get(key)
            if mask_key in results:
                results[mask_key] = results[mask_key][valid_inds]
        return results

    def __call__(self, results):
        """Call function to translate images, bounding boxes, masks and
        semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Translated results.
        """
        if np.random.rand() > self.prob:
            return results
        offset = random_negative(self.offset, self.random_negative_prob)
        self._translate_img(results, offset, self.direction)
        self._translate_bboxes(results, offset)
        # fill_val is 0 by default for BitmapMasks and None for PolygonMasks.
        self._translate_masks(results, offset, self.direction)
        # fill_val set to ``seg_ignore_label`` for the ignored value
        # of segmentation map.
        self._translate_seg(
            results, offset, self.direction, fill_val=self.seg_ignore_label)
        self._filter_invalid(results, min_size=self.min_size)
        return results
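
# --- Illustrative sketch (not part of the original commit) ---
# The bbox translation above is just a clipped coordinate shift. A numpy-only
# demo for the horizontal case, with made-up values: a box pushed past the
# right border ends up with non-positive width and is later dropped by
# ``_filter_invalid``.
def _demo_translate_bboxes_horizontal():
    w, offset = 100, 40
    bboxes = np.array([[10., 20., 50., 60.], [80., 10., 95., 30.]])
    bboxes[:, 0] = np.maximum(0, bboxes[:, 0] + offset)
    bboxes[:, 2] = np.minimum(w, bboxes[:, 2] + offset)
    return bboxes  # the second box becomes degenerate (x1 > x2)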
@PIPELINES.register_module()
class ColorTransform(object):
    """Apply Color transformation to image. The bboxes, masks, and
    segmentations are not modified.

    Args:
        level (int | float): Should be in range [0,_MAX_LEVEL].
        prob (float): The probability for performing Color transformation.
    """

    def __init__(self, level, prob=0.5):
        assert isinstance(level, (int, float)), \
            'The level must be type int or float.'
        assert 0 <= level <= _MAX_LEVEL, \
            'The level should be in range [0,_MAX_LEVEL].'
        assert 0 <= prob <= 1.0, \
            'The probability should be in range [0,1].'
        self.level = level
        self.prob = prob
        self.factor = enhance_level_to_value(level)

    def _adjust_color_img(self, results, factor=1.0):
        """Apply Color transformation to image."""
        for key in results.get('img_fields', ['img']):
            # NOTE: the image is expected to be in BGR format by default
            img = results[key]
            results[key] = mmcv.adjust_color(img, factor).astype(img.dtype)

    def __call__(self, results):
        """Call function for Color transformation.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Colored results.
        """
        if np.random.rand() > self.prob:
            return results
        self._adjust_color_img(results, self.factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(level={self.level}, '
        repr_str += f'prob={self.prob})'
        return repr_str
@PIPELINES.register_module()
class EqualizeTransform(object):
    """Apply Equalize transformation to image. The bboxes, masks and
    segmentations are not modified.

    Args:
        prob (float): The probability for performing Equalize transformation.
    """

    def __init__(self, prob=0.5):
        assert 0 <= prob <= 1.0, \
            'The probability should be in range [0,1].'
        self.prob = prob

    def _imequalize(self, results):
        """Equalizes the histogram of one image."""
        for key in results.get('img_fields', ['img']):
            img = results[key]
            results[key] = mmcv.imequalize(img).astype(img.dtype)

    def __call__(self, results):
        """Call function for Equalize transformation.

        Args:
            results (dict): Results dict from loading pipeline.

        Returns:
            dict: Results after the transformation.
        """
        if np.random.rand() > self.prob:
            return results
        self._imequalize(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob})'
        return repr_str
@PIPELINES.register_module()
class BrightnessTransform(object):
    """Apply Brightness transformation to image. The bboxes, masks and
    segmentations are not modified.

    Args:
        level (int | float): Should be in range [0,_MAX_LEVEL].
        prob (float): The probability for performing Brightness transformation.
    """

    def __init__(self, level, prob=0.5):
        assert isinstance(level, (int, float)), \
            'The level must be type int or float.'
        assert 0 <= level <= _MAX_LEVEL, \
            'The level should be in range [0,_MAX_LEVEL].'
        assert 0 <= prob <= 1.0, \
            'The probability should be in range [0,1].'
        self.level = level
        self.prob = prob
        self.factor = enhance_level_to_value(level)

    def _adjust_brightness_img(self, results, factor=1.0):
        """Adjust the brightness of image."""
        for key in results.get('img_fields', ['img']):
            img = results[key]
            results[key] = mmcv.adjust_brightness(img,
                                                  factor).astype(img.dtype)

    def __call__(self, results):
        """Call function for Brightness transformation.

        Args:
            results (dict): Results dict from loading pipeline.

        Returns:
            dict: Results after the transformation.
        """
        if np.random.rand() > self.prob:
            return results
        self._adjust_brightness_img(results, self.factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(level={self.level}, '
        repr_str += f'prob={self.prob})'
        return repr_str
@PIPELINES.register_module()
class ContrastTransform(object):
    """Apply Contrast transformation to image. The bboxes, masks and
    segmentations are not modified.

    Args:
        level (int | float): Should be in range [0,_MAX_LEVEL].
        prob (float): The probability for performing Contrast transformation.
    """

    def __init__(self, level, prob=0.5):
        assert isinstance(level, (int, float)), \
            'The level must be type int or float.'
        assert 0 <= level <= _MAX_LEVEL, \
            'The level should be in range [0,_MAX_LEVEL].'
        assert 0 <= prob <= 1.0, \
            'The probability should be in range [0,1].'
        self.level = level
        self.prob = prob
        self.factor = enhance_level_to_value(level)

    def _adjust_contrast_img(self, results, factor=1.0):
        """Adjust the image contrast."""
        for key in results.get('img_fields', ['img']):
            img = results[key]
            results[key] = mmcv.adjust_contrast(img, factor).astype(img.dtype)

    def __call__(self, results):
        """Call function for Contrast transformation.

        Args:
            results (dict): Results dict from loading pipeline.

        Returns:
            dict: Results after the transformation.
        """
        if np.random.rand() > self.prob:
            return results
        self._adjust_contrast_img(results, self.factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(level={self.level}, '
        repr_str += f'prob={self.prob})'
        return repr_str
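
# --- Illustrative sketch (not part of the original commit) ---
# Every transform above registers itself in ``PIPELINES``, so configs refer
# to it by type name and ``Compose`` (see compose.py below) builds it via
# ``build_from_cfg``. A hypothetical pipeline fragment, with arbitrary
# example levels and probabilities:
_example_augment_pipeline = [
    dict(type='Shear', level=4, prob=0.5, direction='horizontal'),
    dict(type='Rotate', level=5, prob=0.5, max_rotate_angle=30),
    dict(type='Translate', level=3, prob=0.5, direction='vertical'),
    dict(type='BrightnessTransform', level=3, prob=0.5),
    dict(type='EqualizeTransform', prob=0.3),
]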
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/compose.py
0 → 100644
import collections

from mmcv.utils import build_from_cfg

from ..builder import PIPELINES


@PIPELINES.register_module()
class Compose(object):
    """Compose multiple transforms sequentially.

    Args:
        transforms (Sequence[dict | callable]): Sequence of transform object or
            config dict to be composed.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        self.transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                transform = build_from_cfg(transform, PIPELINES)
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')

    def __call__(self, data):
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict contains the data to transform.

        Returns:
            dict: Transformed data.
        """
        for t in self.transforms:
            data = t(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += f'    {t}'
        format_string += '\n)'
        return format_string
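
# --- Illustrative sketch (not part of the original commit) ---
# ``Compose`` accepts plain callables as well as config dicts, and any step
# that returns ``None`` short-circuits the whole pipeline. A minimal demo
# with made-up steps:
def _demo_compose():
    def tag(data):
        data['tagged'] = True
        return data

    def drop_empty(data):
        return None if not data.get('img_fields') else data

    pipeline = Compose([tag, drop_empty])
    assert pipeline({'img_fields': ['img']}) is not None
    assert pipeline({'img_fields': []}) is None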
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/formating.py
0 → 100644
from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC

from ..builder import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
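
# --- Illustrative sketch (not part of the original commit) ---
# A quick demo of the conversions handled by ``to_tensor``. Note the
# asymmetry: a numpy array keeps its dtype via ``torch.from_numpy``
# (zero-copy), while a bare int or float is wrapped into a one-element
# LongTensor/FloatTensor.
def _demo_to_tensor():
    assert to_tensor(np.ones((2, 3), dtype=np.uint8)).dtype == torch.uint8
    assert to_tensor([1, 2, 3]).shape == (3, )
    assert to_tensor(5).tolist() == [5]      # LongTensor of length 1
    assert to_tensor(0.5).tolist() == [0.5]  # FloatTensor of length 1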
@PIPELINES.register_module()
class ToTensor(object):
    """Convert some results to :obj:`torch.Tensor` by given keys.

    Args:
        keys (Sequence[str]): Keys that need to be converted to Tensor.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Call function to convert data in results to :obj:`torch.Tensor`.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data converted
                to :obj:`torch.Tensor`.
        """
        for key in self.keys:
            results[key] = to_tensor(results[key])
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'


@PIPELINES.register_module()
class ImageToTensor(object):
    """Convert image to :obj:`torch.Tensor` by given keys.

    The dimension order of input image is (H, W, C). The pipeline will convert
    it to (C, H, W). If only 2 dimensions (H, W) are given, the output would be
    (1, H, W).

    Args:
        keys (Sequence[str]): Key of images to be converted to Tensor.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Call function to convert image in results to :obj:`torch.Tensor` and
        transpose the channel order.

        Args:
            results (dict): Result dict contains the image data to convert.

        Returns:
            dict: The result dict contains the image converted
                to :obj:`torch.Tensor` and transposed to (C, H, W) order.
        """
        for key in self.keys:
            img = results[key]
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            results[key] = to_tensor(img.transpose(2, 0, 1))
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'


@PIPELINES.register_module()
class Transpose(object):
    """Transpose some results by given keys.

    Args:
        keys (Sequence[str]): Keys of results to be transposed.
        order (Sequence[int]): Order of transpose.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        """Call function to transpose the channel order of data in results.

        Args:
            results (dict): Result dict contains the data to transpose.

        Returns:
            dict: The result dict contains the data transposed to \
                ``self.order``.
        """
        for key in self.keys:
            results[key] = results[key].transpose(self.order)
        return results

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, order={self.order})'


@PIPELINES.register_module()
class ToDataContainer(object):
    """Convert results to :obj:`mmcv.DataContainer` by given fields.

    Args:
        fields (Sequence[dict]): Each field is a dict like
            ``dict(key='xxx', **kwargs)``. The ``key`` in result will
            be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
            Default: ``(dict(key='img', stack=True), dict(key='gt_bboxes'),
            dict(key='gt_labels'))``.
    """

    def __init__(self,
                 fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
                         dict(key='gt_labels'))):
        self.fields = fields

    def __call__(self, results):
        """Call function to convert data in results to
        :obj:`mmcv.DataContainer`.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data converted to \
                :obj:`mmcv.DataContainer`.
        """
        for field in self.fields:
            field = field.copy()
            key = field.pop('key')
            results[key] = DC(results[key], **field)
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(fields={self.fields})'


@PIPELINES.register_module()
class DefaultFormatBundle(object):
    """Default formatting bundle.

    It simplifies the pipeline of formatting common fields, including "img",
    "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
    These fields are formatted as follows.

    - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
    - proposals: (1)to tensor, (2)to DataContainer
    - gt_bboxes: (1)to tensor, (2)to DataContainer
    - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
    - gt_labels: (1)to tensor, (2)to DataContainer
    - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \
        (3)to DataContainer (stack=True)
    """

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data that is formatted with \
                default bundle.
        """
        if 'img' in results:
            img = results['img']
            # add default meta keys
            results = self._add_default_meta_keys(results)
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
            results['img'] = DC(to_tensor(img), stack=True)
        for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
            if key not in results:
                continue
            results[key] = DC(to_tensor(results[key]))
        if 'gt_masks' in results:
            results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = DC(
                to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
        return results

    def _add_default_meta_keys(self, results):
        """Add default meta keys.

        We set default meta keys including `pad_shape`, `scale_factor` and
        `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and
        `Pad` are implemented during the whole pipeline.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            results (dict): Updated result dict contains the data to convert.
        """
        img = results['img']
        results.setdefault('pad_shape', img.shape)
        results.setdefault('scale_factor', 1.0)
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results.setdefault(
            'img_norm_cfg',
            dict(
                mean=np.zeros(num_channels, dtype=np.float32),
                std=np.ones(num_channels, dtype=np.float32),
                to_rgb=False))
        return results

    def __repr__(self):
        return self.__class__.__name__


@PIPELINES.register_module()
class Collect(object):
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically keys
    is set to some subset of "img", "proposals", "gt_bboxes",
    "gt_bboxes_ignore", "gt_labels", and/or "gt_masks".

    The "img_meta" item is always populated. The contents of the "img_meta"
    dictionary depend on "meta_keys". By default this includes:

        - "img_shape": shape of the image input to the network as a tuple \
            (h, w, c). Note that images may be zero padded on the \
            bottom/right if the batch tensor is larger than this shape.
        - "scale_factor": a float indicating the preprocessing scale
        - "flip": a boolean indicating if image flip transform was used
        - "filename": path to the image file
        - "ori_shape": original shape of the image as a tuple (h, w, c)
        - "pad_shape": image shape after padding
        - "img_norm_cfg": a dict of normalization information:

            - mean - per channel mean subtraction
            - std - per channel std divisor
            - to_rgb - bool indicating if bgr was converted to rgb

    Args:
        keys (Sequence[str]): Keys of results to be collected in ``data``.
        meta_keys (Sequence[str], optional): Meta keys to be converted to
            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
            Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
            'pad_shape', 'scale_factor', 'flip', 'flip_direction',
            'img_norm_cfg')``
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Call function to collect keys in results. The keys in ``meta_keys``
        will be converted to :obj:mmcv.DataContainer.

        Args:
            results (dict): Result dict contains the data to collect.

        Returns:
            dict: The result dict contains the following keys

                - keys in ``self.keys``
                - ``img_metas``
        """
        data = {}
        img_meta = {}
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data['img_metas'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, meta_keys={self.meta_keys})'


@PIPELINES.register_module()
class WrapFieldsToLists(object):
    """Wrap fields of the data dictionary into lists for evaluation.

    This class can be used as a last step of a test or validation
    pipeline for single image evaluation or inference.

    Example:
        >>> test_pipeline = [
        >>>    dict(type='LoadImageFromFile'),
        >>>    dict(type='Normalize',
                    mean=[123.675, 116.28, 103.53],
                    std=[58.395, 57.12, 57.375],
                    to_rgb=True),
        >>>    dict(type='Pad', size_divisor=32),
        >>>    dict(type='ImageToTensor', keys=['img']),
        >>>    dict(type='Collect', keys=['img']),
        >>>    dict(type='WrapFieldsToLists')
        >>> ]
    """

    def __call__(self, results):
        """Call function to wrap fields into lists.

        Args:
            results (dict): Result dict contains the data to wrap.

        Returns:
            dict: The result dict where value of ``self.keys`` are wrapped \
                into list.
        """
        # Wrap dict fields into lists
        for key, val in results.items():
            results[key] = [val]
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}()'
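
# --- Illustrative sketch (not part of the original commit) ---
# What ``Collect`` produces at the end of a pipeline: the requested keys pass
# through unchanged, while the meta keys are bundled into one cpu-only
# DataContainer under 'img_metas'. The keys and values here are made up for
# the demo; it assumes mmcv is importable (``DC`` is imported above).
def _demo_collect():
    collect = Collect(keys=['img'], meta_keys=('filename', 'ori_shape'))
    results = {
        'img': np.zeros((4, 4, 3), dtype=np.float32),
        'filename': 'demo.jpg',
        'ori_shape': (4, 4, 3),
    }
    data = collect(results)
    return data['img'], data['img_metas']  # ndarray, DataContainer(dict)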
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/instaboost.py
0 → 100644
import numpy as np

from ..builder import PIPELINES


@PIPELINES.register_module()
class InstaBoost(object):
    r"""Data augmentation method in `InstaBoost: Boosting Instance
    Segmentation Via Probability Map Guided Copy-Pasting
    <https://arxiv.org/abs/1908.07801>`_.

    Refer to https://github.com/GothicAi/Instaboost for implementation details.
    """

    def __init__(self,
                 action_candidate=('normal', 'horizontal', 'skip'),
                 action_prob=(1, 0, 0),
                 scale=(0.8, 1.2),
                 dx=15,
                 dy=15,
                 theta=(-1, 1),
                 color_prob=0.5,
                 hflag=False,
                 aug_ratio=0.5):
        try:
            import instaboostfast as instaboost
        except ImportError:
            raise ImportError(
                'Please run "pip install instaboostfast" '
                'to install instaboostfast first for instaboost augmentation.')
        self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
                                               scale, dx, dy, theta,
                                               color_prob, hflag)
        self.aug_ratio = aug_ratio

    def _load_anns(self, results):
        labels = results['ann_info']['labels']
        masks = results['ann_info']['masks']
        bboxes = results['ann_info']['bboxes']
        n = len(labels)

        anns = []
        for i in range(n):
            label = labels[i]
            bbox = bboxes[i]
            mask = masks[i]
            x1, y1, x2, y2 = bbox
            # assert (x2 - x1) >= 1 and (y2 - y1) >= 1
            bbox = [x1, y1, x2 - x1, y2 - y1]
            anns.append({
                'category_id': label,
                'segmentation': mask,
                'bbox': bbox
            })

        return anns

    def _parse_anns(self, results, anns, img):
        gt_bboxes = []
        gt_labels = []
        gt_masks_ann = []
        for ann in anns:
            x1, y1, w, h = ann['bbox']
            # TODO: a more essential bug needs to be fixed in instaboost
            if w <= 0 or h <= 0:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            gt_bboxes.append(bbox)
            gt_labels.append(ann['category_id'])
            gt_masks_ann.append(ann['segmentation'])
        gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
        gt_labels = np.array(gt_labels, dtype=np.int64)
        results['ann_info']['labels'] = gt_labels
        results['ann_info']['bboxes'] = gt_bboxes
        results['ann_info']['masks'] = gt_masks_ann
        results['img'] = img
        return results

    def __call__(self, results):
        img = results['img']
        orig_type = img.dtype
        anns = self._load_anns(results)
        if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
            try:
                import instaboostfast as instaboost
            except ImportError:
                raise ImportError('Please run "pip install instaboostfast" '
                                  'to install instaboostfast first.')
            anns, img = instaboost.get_new_data(
                anns, img.astype(np.uint8), self.cfg, background=None)

        results = self._parse_anns(results, anns, img.astype(orig_type))
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(cfg={self.cfg}, aug_ratio={self.aug_ratio})'
        return repr_str
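
# --- Illustrative sketch (not part of the original commit) ---
# InstaBoost works on COCO-style ``[x, y, w, h]`` boxes, so ``_load_anns``
# converts from mmdet's ``[x1, y1, x2, y2]`` format and ``_parse_anns``
# converts back. The round trip, with made-up numbers:
def _demo_bbox_roundtrip():
    x1, y1, x2, y2 = 10, 20, 50, 60
    coco_bbox = [x1, y1, x2 - x1, y2 - y1]  # -> [10, 20, 40, 40]
    x, y, w, h = coco_bbox
    assert [x, y, x + w, y + h] == [x1, y1, x2, y2]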
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/loading.py
0 → 100644
import os.path as osp

import mmcv
import numpy as np
import pycocotools.mask as maskUtils

from mmdet.core import BitmapMasks, PolygonMasks
from ..builder import PIPELINES


@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Load an image from file.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename"). Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk')):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def __call__(self, results):
        """Call functions to load image and get image meta information.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results['img_prefix'] is not None:
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])
        else:
            filename = results['img_info']['filename']

        img_bytes = self.file_client.get(filename)
        img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        results['img_fields'] = ['img']
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32}, '
                    f"color_type='{self.color_type}', "
                    f'file_client_args={self.file_client_args})')
        return repr_str
@PIPELINES.register_module()
class LoadImageFromWebcam(LoadImageFromFile):
    """Load an image from webcam.

    Similar to :obj:`LoadImageFromFile`, but the image read from webcam is in
    ``results['img']``.
    """

    def __call__(self, results):
        """Call functions to add image meta information.

        Args:
            results (dict): Result dict with Webcam read image in
                ``results['img']``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        img = results['img']
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = None
        results['ori_filename'] = None
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        results['img_fields'] = ['img']
        return results
@PIPELINES.register_module()
class LoadMultiChannelImageFromFiles(object):
    """Load multi-channel images from a list of separate channel files.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename", which is expected to be a list of filenames).
    Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='unchanged',
                 file_client_args=dict(backend='disk')):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def __call__(self, results):
        """Call functions to load multiple images and get images meta
        information.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded images and meta information.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results['img_prefix'] is not None:
            filename = [
                osp.join(results['img_prefix'], fname)
                for fname in results['img_info']['filename']
            ]
        else:
            filename = results['img_info']['filename']

        img = []
        for name in filename:
            img_bytes = self.file_client.get(name)
            img.append(mmcv.imfrombytes(img_bytes, flag=self.color_type))
        img = np.stack(img, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape
        results['scale_factor'] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32}, '
                    f"color_type='{self.color_type}', "
                    f'file_client_args={self.file_client_args})')
        return repr_str
@PIPELINES.register_module()
class LoadAnnotations(object):
    """Load multiple types of annotations.

    Args:
        with_bbox (bool): Whether to parse and load the bbox annotation.
            Default: True.
        with_label (bool): Whether to parse and load the label annotation.
            Default: True.
        with_mask (bool): Whether to parse and load the mask annotation.
            Default: False.
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Default: False.
        poly2mask (bool): Whether to convert the instance masks from polygons
            to bitmaps. Default: True.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 with_bbox=True,
                 with_label=True,
                 with_mask=False,
                 with_seg=False,
                 poly2mask=True,
                 file_client_args=dict(backend='disk')):
        self.with_bbox = with_bbox
        self.with_label = with_label
        self.with_mask = with_mask
        self.with_seg = with_seg
        self.poly2mask = poly2mask
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def _load_bboxes(self, results):
        """Private function to load bounding box annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """
        ann_info = results['ann_info']
        results['gt_bboxes'] = ann_info['bboxes'].copy()

        gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
        if gt_bboxes_ignore is not None:
            results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
            results['bbox_fields'].append('gt_bboxes_ignore')
        results['bbox_fields'].append('gt_bboxes')
        return results

    def _load_labels(self, results):
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded label annotations.
        """
        results['gt_labels'] = results['ann_info']['labels'].copy()
        return results

    def _poly2mask(self, mask_ann, img_h, img_w):
        """Private function to convert masks represented with polygon to
        bitmaps.

        Args:
            mask_ann (list | dict): Polygon mask annotation input.
            img_h (int): The height of output mask.
            img_w (int): The width of output mask.

        Returns:
            numpy.ndarray: The decoded bitmap mask of shape (img_h, img_w).
        """
        if isinstance(mask_ann, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
            rle = maskUtils.merge(rles)
        elif isinstance(mask_ann['counts'], list):
            # uncompressed RLE
            rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
        else:
            # rle
            rle = mask_ann
        mask = maskUtils.decode(rle)
        return mask

    def process_polygons(self, polygons):
        """Convert polygons to list of ndarray and filter invalid polygons.

        Args:
            polygons (list[list]): Polygons of one instance.

        Returns:
            list[numpy.ndarray]: Processed polygons.
        """
        polygons = [np.array(p) for p in polygons]
        valid_polygons = []
        for polygon in polygons:
            if len(polygon) % 2 == 0 and len(polygon) >= 6:
                valid_polygons.append(polygon)
        return valid_polygons

    def _load_masks(self, results):
        """Private function to load mask annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded mask annotations.
                If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
                :obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks` is used.
        """
        h, w = results['img_info']['height'], results['img_info']['width']
        gt_masks = results['ann_info']['masks']
        if self.poly2mask:
            gt_masks = BitmapMasks(
                [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
        else:
            gt_masks = PolygonMasks(
                [self.process_polygons(polygons) for polygons in gt_masks], h,
                w)
        results['gt_masks'] = gt_masks
        results['mask_fields'].append('gt_masks')
        return results

    def _load_semantic_seg(self, results):
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        filename = osp.join(results['seg_prefix'],
                            results['ann_info']['seg_map'])
        img_bytes = self.file_client.get(filename)
        results['gt_semantic_seg'] = mmcv.imfrombytes(
            img_bytes, flag='unchanged').squeeze()
        results['seg_fields'].append('gt_semantic_seg')
        return results

    def __call__(self, results):
        """Call function to load multiple types of annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box, label, mask and
                semantic segmentation annotations.
        """
        if self.with_bbox:
            results = self._load_bboxes(results)
            if results is None:
                return None
        if self.with_label:
            results = self._load_labels(results)
        if self.with_mask:
            results = self._load_masks(results)
        if self.with_seg:
            results = self._load_semantic_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(with_bbox={self.with_bbox}, '
        repr_str += f'with_label={self.with_label}, '
        repr_str += f'with_mask={self.with_mask}, '
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'poly2mask={self.poly2mask}, '
        repr_str += f'file_client_args={self.file_client_args})'
        return repr_str
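
# --- Illustrative sketch (not part of the original commit) ---
# What ``_poly2mask`` does for a polygon annotation, assuming pycocotools is
# installed (it is imported above as ``maskUtils``). The triangle below is
# made up; the result is a (img_h, img_w) uint8 bitmap.
def _demo_poly2mask():
    img_h, img_w = 8, 8
    polygon = [[1., 1., 6., 1., 1., 6.]]  # one part: x0, y0, x1, y1, x2, y2
    rles = maskUtils.frPyObjects(polygon, img_h, img_w)
    rle = maskUtils.merge(rles)
    return maskUtils.decode(rle)  # uint8 ndarray of shape (8, 8)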
@PIPELINES.register_module()
class LoadProposals(object):
    """Load proposal pipeline.

    Required key is "proposals". Updated keys are "proposals", "bbox_fields".

    Args:
        num_max_proposals (int, optional): Maximum number of proposals to load.
            If not specified, all proposals will be loaded.
    """

    def __init__(self, num_max_proposals=None):
        self.num_max_proposals = num_max_proposals

    def __call__(self, results):
        """Call function to load proposals from file.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded proposal annotations.
        """
        proposals = results['proposals']
        if proposals.shape[1] not in (4, 5):
            raise AssertionError(
                'proposals should have shapes (n, 4) or (n, 5), '
                f'but found {proposals.shape}')
        proposals = proposals[:, :4]

        if self.num_max_proposals is not None:
            proposals = proposals[:self.num_max_proposals]

        if len(proposals) == 0:
            proposals = np.array([[0, 0, 0, 0]], dtype=np.float32)
        results['proposals'] = proposals
        results['bbox_fields'].append('proposals')
        return results

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(num_max_proposals={self.num_max_proposals})'
@PIPELINES.register_module()
class FilterAnnotations(object):
    """Filter invalid annotations.

    Args:
        min_gt_bbox_wh (tuple[int]): Minimum width and height of ground truth
            boxes.
    """

    def __init__(self, min_gt_bbox_wh):
        # TODO: add more filter options
        self.min_gt_bbox_wh = min_gt_bbox_wh

    def __call__(self, results):
        assert 'gt_bboxes' in results
        gt_bboxes = results['gt_bboxes']
        w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
        h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
        keep = (w > self.min_gt_bbox_wh[0]) & (h > self.min_gt_bbox_wh[1])
        if not keep.any():
            return None
        else:
            keys = ('gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg')
            for key in keys:
                if key in results:
                    results[key] = results[key][keep]
            return results
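
# --- Illustrative sketch (not part of the original commit) ---
# The filtering in ``FilterAnnotations`` (and the ``_filter_invalid`` helpers
# in auto_augment.py) is a vectorized boolean mask over box widths and
# heights. A numpy-only demo with made-up boxes:
def _demo_filter_small_boxes():
    gt_bboxes = np.array([[0., 0., 10., 10.],   # kept
                          [5., 5., 6., 5.5]])   # dropped (w = 1, h = 0.5)
    min_w, min_h = 1, 1
    w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
    h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
    keep = (w > min_w) & (h > min_h)
    return gt_bboxes[keep]  # only the first box survives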
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/test_time_aug.py
0 → 100644
import warnings

import mmcv

from ..builder import PIPELINES
from .compose import Compose


@PIPELINES.register_module()
class MultiScaleFlipAug(object):
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=[(1333, 400), (1333, 800)],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
            flip=[False, True, False, True]
            ...
        )

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (tuple | list[tuple] | None): Images scales for resizing.
        scale_factor (float | list[float] | None): Scale factors for resizing.
        flip (bool): Whether to apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal" and "vertical". If flip_direction is list,
            multiple flip augmentations will be applied.
            It has no effect when flip == False. Default: "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale=None,
                 scale_factor=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        assert (img_scale is None) ^ (scale_factor is None), (
            'Exactly one of img_scale and scale_factor must be set')
        if img_scale is not None:
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
            self.scale_key = 'scale'
            assert mmcv.is_list_of(self.img_scale, tuple)
        else:
            self.img_scale = scale_factor if isinstance(
                scale_factor, list) else [scale_factor]
            self.scale_key = 'scale_factor'

        self.flip = flip
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if (self.flip
                and not any([t['type'] == 'RandomFlip' for t in transforms])):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
                into a list.
        """
        aug_data = []
        flip_args = [(False, None)]
        if self.flip:
            flip_args += [(True, direction)
                          for direction in self.flip_direction]
        for scale in self.img_scale:
            for flip, direction in flip_args:
                _results = results.copy()
                _results[self.scale_key] = scale
                _results['flip'] = flip
                _results['flip_direction'] = direction
                data = self.transforms(_results)
                aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str
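
# --- Illustrative sketch (not part of the original commit) ---
# The augmentation grid enumerated by ``MultiScaleFlipAug.__call__`` is the
# cross product of scales and flip settings. With two scales and flip=True
# (horizontal only), four variants are produced:
def _demo_aug_grid():
    img_scales = [(1333, 400), (1333, 800)]
    flip_args = [(False, None), (True, 'horizontal')]
    grid = [(scale, flip, direction)
            for scale in img_scales
            for flip, direction in flip_args]
    assert len(grid) == 4
    return grid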
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/pipelines/transforms.py
0 → 100644
import inspect

import mmcv
import numpy as np
from numpy import random

from mmdet.core import PolygonMasks
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from ..builder import PIPELINES

try:
    from imagecorruptions import corrupt
except ImportError:
    corrupt = None

try:
    import albumentations
    from albumentations import Compose
except ImportError:
    albumentations = None
    Compose = None


@PIPELINES.register_module()
class Resize(object):
    """Resize images & bbox & mask.

    This transform resizes the input image to some scale. Bboxes and masks are
    then resized with the same scale factor. If the input dict contains the key
    "scale", then the scale in the input dict is used, otherwise the specified
    scale in the init method is used. If the input dict contains the key
    "scale_factor" (if MultiScaleFlipAug does not give img_scale but
    scale_factor), the actual scale will be computed by image shape and
    scale_factor.

    `img_scale` can either be a tuple (single-scale) or a list of tuple
    (multi-scale). There are 3 multiscale modes:

    - ``ratio_range is not None``: randomly sample a ratio from the ratio \
      range and multiply it with the image scale.
    - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \
      sample a scale from the multiscale range.
    - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \
      sample a scale from multiple scales.

    Args:
        img_scale (tuple or list[tuple]): Images scales for resizing.
        multiscale_mode (str): Either "range" or "value".
        ratio_range (tuple[float]): (min_ratio, max_ratio)
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image.
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generate slightly different results. Defaults
            to 'cv2'.
        override (bool, optional): Whether to override `scale` and
            `scale_factor` so as to call resize twice. Default False. If True,
            after the first resizing, the existing `scale` and `scale_factor`
            will be ignored so the second resizing can be allowed.
            This option is a work-around for multiple times of resize in DETR.
            Defaults to False.
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True,
                 bbox_clip_border=True,
                 backend='cv2',
                 override=False):
        if img_scale is None:
            self.img_scale = None
        else:
            if isinstance(img_scale, list):
                self.img_scale = img_scale
            else:
                self.img_scale = [img_scale]
            assert mmcv.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # mode 1: given a scale and a range of image ratio
            assert len(self.img_scale) == 1
        else:
            # mode 2: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']

        self.backend = backend
        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        # TODO: refactor the override option in Resize
        self.override = override
        self.bbox_clip_border = bbox_clip_border

    @staticmethod
    def random_select(img_scales):
        """Randomly select an img_scale from given candidates.

        Args:
            img_scales (list[tuple]): Images scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
                where ``img_scale`` is the selected image scale and \
                ``scale_idx`` is the selected index in the given candidates.
        """
        assert mmcv.is_list_of(img_scales, tuple)
        scale_idx = np.random.randint(len(img_scales))
        img_scale = img_scales[scale_idx]
        return img_scale, scale_idx

    @staticmethod
    def random_sample(img_scales):
        """Randomly sample an img_scale when ``multiscale_mode=='range'``.

        Args:
            img_scales (list[tuple]): Images scale range for sampling.
                There must be two tuples in img_scales, which specify the lower
                and upper bound of image scales.

        Returns:
            (tuple, None): Returns a tuple ``(img_scale, None)``, where \
                ``img_scale`` is sampled scale and None is just a placeholder \
                to be consistent with :func:`random_select`.
        """
        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        img_scale_long = [max(s) for s in img_scales]
        img_scale_short = [min(s) for s in img_scales]
        long_edge = np.random.randint(
            min(img_scale_long),
            max(img_scale_long) + 1)
        short_edge = np.random.randint(
            min(img_scale_short),
            max(img_scale_short) + 1)
        img_scale = (long_edge, short_edge)
        return img_scale, None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Randomly sample an img_scale when ``ratio_range`` is specified.

        A ratio will be randomly sampled from the range specified by
        ``ratio_range``. Then it would be multiplied with ``img_scale`` to
        generate sampled scale.

        Args:
            img_scale (tuple): Images scale base to multiply with ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to scale
                the ``img_scale``.

        Returns:
            (tuple, None): Returns a tuple ``(scale, None)``, where \
                ``scale`` is sampled ratio multiplied with ``img_scale`` and \
                None is just a placeholder to be consistent with \
                :func:`random_select`.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
        return scale, None
def
_random_scale
(
self
,
results
):
"""Randomly sample an img_scale according to ``ratio_range`` and
``multiscale_mode``.
If ``ratio_range`` is specified, a ratio will be sampled and be
multiplied with ``img_scale``.
If multiple scales are specified by ``img_scale``, a scale will be
sampled according to ``multiscale_mode``.
Otherwise, single scale will be used.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: Two new keys 'scale` and 'scale_idx` are added into
\
``results``, which would be used by subsequent pipelines.
"""
if
self
.
ratio_range
is
not
None
:
scale
,
scale_idx
=
self
.
random_sample_ratio
(
self
.
img_scale
[
0
],
self
.
ratio_range
)
elif
len
(
self
.
img_scale
)
==
1
:
scale
,
scale_idx
=
self
.
img_scale
[
0
],
0
elif
self
.
multiscale_mode
==
'range'
:
scale
,
scale_idx
=
self
.
random_sample
(
self
.
img_scale
)
elif
self
.
multiscale_mode
==
'value'
:
scale
,
scale_idx
=
self
.
random_select
(
self
.
img_scale
)
else
:
raise
NotImplementedError
results
[
'scale'
]
=
scale
results
[
'scale_idx'
]
=
scale_idx
def
_resize_img
(
self
,
results
):
"""Resize images with ``results['scale']``."""
for
key
in
results
.
get
(
'img_fields'
,
[
'img'
]):
if
self
.
keep_ratio
:
img
,
scale_factor
=
mmcv
.
imrescale
(
results
[
key
],
results
[
'scale'
],
return_scale
=
True
,
backend
=
self
.
backend
)
# the w_scale and h_scale has minor difference
# a real fix should be done in the mmcv.imrescale in the future
new_h
,
new_w
=
img
.
shape
[:
2
]
h
,
w
=
results
[
key
].
shape
[:
2
]
w_scale
=
new_w
/
w
h_scale
=
new_h
/
h
else
:
img
,
w_scale
,
h_scale
=
mmcv
.
imresize
(
results
[
key
],
results
[
'scale'
],
return_scale
=
True
,
backend
=
self
.
backend
)
results
[
key
]
=
img
scale_factor
=
np
.
array
([
w_scale
,
h_scale
,
w_scale
,
h_scale
],
dtype
=
np
.
float32
)
results
[
'img_shape'
]
=
img
.
shape
# in case that there is no padding
results
[
'pad_shape'
]
=
img
.
shape
results
[
'scale_factor'
]
=
scale_factor
results
[
'keep_ratio'
]
=
self
.
keep_ratio
def
_resize_bboxes
(
self
,
results
):
"""Resize bounding boxes with ``results['scale_factor']``."""
for
key
in
results
.
get
(
'bbox_fields'
,
[]):
bboxes
=
results
[
key
]
*
results
[
'scale_factor'
]
if
self
.
bbox_clip_border
:
img_shape
=
results
[
'img_shape'
]
bboxes
[:,
0
::
2
]
=
np
.
clip
(
bboxes
[:,
0
::
2
],
0
,
img_shape
[
1
])
bboxes
[:,
1
::
2
]
=
np
.
clip
(
bboxes
[:,
1
::
2
],
0
,
img_shape
[
0
])
results
[
key
]
=
bboxes
def
_resize_masks
(
self
,
results
):
"""Resize masks with ``results['scale']``"""
for
key
in
results
.
get
(
'mask_fields'
,
[]):
if
results
[
key
]
is
None
:
continue
if
self
.
keep_ratio
:
results
[
key
]
=
results
[
key
].
rescale
(
results
[
'scale'
])
else
:
results
[
key
]
=
results
[
key
].
resize
(
results
[
'img_shape'
][:
2
])
def
_resize_seg
(
self
,
results
):
"""Resize semantic segmentation map with ``results['scale']``."""
for
key
in
results
.
get
(
'seg_fields'
,
[]):
if
self
.
keep_ratio
:
gt_seg
=
mmcv
.
imrescale
(
results
[
key
],
results
[
'scale'
],
interpolation
=
'nearest'
,
backend
=
self
.
backend
)
else
:
gt_seg
=
mmcv
.
imresize
(
results
[
key
],
results
[
'scale'
],
interpolation
=
'nearest'
,
backend
=
self
.
backend
)
results
[
'gt_semantic_seg'
]
=
gt_seg
def
__call__
(
self
,
results
):
"""Call function to resize images, bounding boxes, masks, semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
\
'keep_ratio' keys are added into result dict.
"""
if
'scale'
not
in
results
:
if
'scale_factor'
in
results
:
img_shape
=
results
[
'img'
].
shape
[:
2
]
scale_factor
=
results
[
'scale_factor'
]
assert
isinstance
(
scale_factor
,
float
)
results
[
'scale'
]
=
tuple
(
[
int
(
x
*
scale_factor
)
for
x
in
img_shape
][::
-
1
])
else
:
self
.
_random_scale
(
results
)
else
:
if
not
self
.
override
:
assert
'scale_factor'
not
in
results
,
(
'scale and scale_factor cannot be both set.'
)
else
:
results
.
pop
(
'scale'
)
if
'scale_factor'
in
results
:
results
.
pop
(
'scale_factor'
)
self
.
_random_scale
(
results
)
self
.
_resize_img
(
results
)
self
.
_resize_bboxes
(
results
)
self
.
_resize_masks
(
results
)
self
.
_resize_seg
(
results
)
return
results
def
__repr__
(
self
):
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(img_scale=
{
self
.
img_scale
}
, '
repr_str
+=
f
'multiscale_mode=
{
self
.
multiscale_mode
}
, '
repr_str
+=
f
'ratio_range=
{
self
.
ratio_range
}
, '
repr_str
+=
f
'keep_ratio=
{
self
.
keep_ratio
}
, '
repr_str
+=
f
'bbox_clip_border=
{
self
.
bbox_clip_border
}
)'
return
repr_str
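
# --- Illustration (not part of the original file) ---
# A minimal sketch of how the three multiscale modes documented above are
# typically configured. The dicts below are pipeline-config fragments in
# the usual mmdetection style; the variable names are illustrative
# assumptions.

# mode 1: one base scale, multiplied by a random ratio in [0.5, 1.5]
resize_with_ratio = dict(
    type='Resize', img_scale=(1333, 800), ratio_range=(0.5, 1.5))

# mode 2: sample a scale from the range spanned by two (long, short) tuples
resize_from_range = dict(
    type='Resize',
    img_scale=[(1333, 640), (1333, 800)],
    multiscale_mode='range')

# mode 3: pick one of several fixed scales uniformly at random
resize_from_values = dict(
    type='Resize',
    img_scale=[(1333, 640), (1333, 672), (1333, 704)],
    multiscale_mode='value')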
@PIPELINES.register_module()
class RandomFlip(object):
    """Flip the image & bbox & mask.

    If the input dict contains the key "flip", then that flag will be
    used; otherwise it will be randomly decided by a ratio specified in
    the init method.

    When random flip is enabled, ``flip_ratio``/``direction`` can either
    be a float/string or a tuple of float/string. There are 3 flip modes:

    - ``flip_ratio`` is float, ``direction`` is string: the image will be
      ``direction``ly flipped with probability of ``flip_ratio``.
      E.g., ``flip_ratio=0.5``, ``direction='horizontal'``,
      then the image will be horizontally flipped with probability of 0.5.
    - ``flip_ratio`` is float, ``direction`` is list of string: the image
      will be ``direction[i]``ly flipped with probability of
      ``flip_ratio/len(direction)``.
      E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``,
      then the image will be horizontally flipped with probability of
      0.25, and vertically with probability of 0.25.
    - ``flip_ratio`` is list of float, ``direction`` is list of string:
      given ``len(flip_ratio) == len(direction)``, the image will be
      ``direction[i]``ly flipped with probability of ``flip_ratio[i]``.
      E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal',
      'vertical']``, then the image will be horizontally flipped with
      probability of 0.3, and vertically with probability of 0.5.

    Args:
        flip_ratio (float | list[float], optional): The flipping
            probability. Default: None.
        direction (str | list[str], optional): The flipping direction.
            Options are 'horizontal', 'vertical', 'diagonal'.
            Default: 'horizontal'. If the input is a list, its length must
            equal that of ``flip_ratio``. Each element in ``flip_ratio``
            indicates the flip probability of the corresponding direction.
    """

    def __init__(self, flip_ratio=None, direction='horizontal'):
        if isinstance(flip_ratio, list):
            assert mmcv.is_list_of(flip_ratio, float)
            assert 0 <= sum(flip_ratio) <= 1
        elif isinstance(flip_ratio, float):
            assert 0 <= flip_ratio <= 1
        elif flip_ratio is None:
            pass
        else:
            raise ValueError('flip_ratios must be None, float, '
                             'or list of float')
        self.flip_ratio = flip_ratio

        valid_directions = ['horizontal', 'vertical', 'diagonal']
        if isinstance(direction, str):
            assert direction in valid_directions
        elif isinstance(direction, list):
            assert mmcv.is_list_of(direction, str)
            assert set(direction).issubset(set(valid_directions))
        else:
            raise ValueError('direction must be either str or list of str')
        self.direction = direction

        if isinstance(flip_ratio, list):
            assert len(self.flip_ratio) == len(self.direction)

    def bbox_flip(self, bboxes, img_shape, direction):
        """Flip bboxes according to ``direction``.

        Args:
            bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k).
            img_shape (tuple[int]): Image shape (height, width).
            direction (str): Flip direction. Options are 'horizontal',
                'vertical' and 'diagonal'.

        Returns:
            numpy.ndarray: Flipped bounding boxes.
        """
        assert bboxes.shape[-1] % 4 == 0
        flipped = bboxes.copy()
        if direction == 'horizontal':
            w = img_shape[1]
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
        elif direction == 'vertical':
            h = img_shape[0]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        elif direction == 'diagonal':
            w = img_shape[1]
            h = img_shape[0]
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        else:
            raise ValueError(f"Invalid flipping direction '{direction}'")
        return flipped

    def __call__(self, results):
        """Call function to flip bounding boxes, masks and semantic
        segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Flipped results; the 'flip' and 'flip_direction' keys
                are added into the result dict.
        """
        if 'flip' not in results:
            if isinstance(self.direction, list):
                # None means non-flip
                direction_list = self.direction + [None]
            else:
                # None means non-flip
                direction_list = [self.direction, None]

            if isinstance(self.flip_ratio, list):
                non_flip_ratio = 1 - sum(self.flip_ratio)
                flip_ratio_list = self.flip_ratio + [non_flip_ratio]
            else:
                non_flip_ratio = 1 - self.flip_ratio
                # exclude non-flip
                single_ratio = self.flip_ratio / (len(direction_list) - 1)
                flip_ratio_list = [single_ratio] * (len(direction_list) -
                                                    1) + [non_flip_ratio]

            cur_dir = np.random.choice(direction_list, p=flip_ratio_list)

            results['flip'] = cur_dir is not None
        if 'flip_direction' not in results:
            results['flip_direction'] = cur_dir
        if results['flip']:
            # flip image
            for key in results.get('img_fields', ['img']):
                results[key] = mmcv.imflip(
                    results[key], direction=results['flip_direction'])
            # flip bboxes
            for key in results.get('bbox_fields', []):
                results[key] = self.bbox_flip(results[key],
                                              results['img_shape'],
                                              results['flip_direction'])
            # flip masks
            for key in results.get('mask_fields', []):
                results[key] = results[key].flip(results['flip_direction'])
            # flip segs
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imflip(
                    results[key], direction=results['flip_direction'])
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
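
# --- Illustration (not part of the original file) ---
# Quick numeric check of the horizontal flip formula in bbox_flip above:
# for an image of width w, x1' = w - x2 and x2' = w - x1.
import numpy as np

w = 100
bboxes = np.array([[10., 20., 30., 40.]])  # (x1, y1, x2, y2)
flipped = bboxes.copy()
flipped[..., 0::4] = w - bboxes[..., 2::4]
flipped[..., 2::4] = w - bboxes[..., 0::4]
print(flipped)  # [[70. 20. 90. 40.]] -- widths and heights are preserved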
@PIPELINES.register_module()
class Pad(object):
    """Pad the image & mask.

    There are two padding modes: (1) pad to a fixed size and (2) pad to
    the minimum size that is divisible by some number.
    Added keys are "pad_shape", "pad_fixed_size" and "pad_size_divisor".

    Args:
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (float, optional): Padding value, 0 by default.
    """

    def __init__(self, size=None, size_divisor=None, pad_val=0):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        # only one of size and size_divisor should be valid
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, results):
        """Pad images according to ``self.size``."""
        for key in results.get('img_fields', ['img']):
            if self.size is not None:
                padded_img = mmcv.impad(
                    results[key], shape=self.size, pad_val=self.pad_val)
            elif self.size_divisor is not None:
                padded_img = mmcv.impad_to_multiple(
                    results[key], self.size_divisor, pad_val=self.pad_val)
            results[key] = padded_img
        results['pad_shape'] = padded_img.shape
        results['pad_fixed_size'] = self.size
        results['pad_size_divisor'] = self.size_divisor

    def _pad_masks(self, results):
        """Pad masks according to ``results['pad_shape']``."""
        pad_shape = results['pad_shape'][:2]
        for key in results.get('mask_fields', []):
            results[key] = results[key].pad(pad_shape, pad_val=self.pad_val)

    def _pad_seg(self, results):
        """Pad semantic segmentation map according to
        ``results['pad_shape']``."""
        for key in results.get('seg_fields', []):
            results[key] = mmcv.impad(
                results[key], shape=results['pad_shape'][:2])

    def __call__(self, results):
        """Call function to pad images, masks and semantic segmentation
        maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        self._pad_img(results)
        self._pad_masks(results)
        self._pad_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(size={self.size}, '
        repr_str += f'size_divisor={self.size_divisor}, '
        repr_str += f'pad_val={self.pad_val})'
        return repr_str
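
# --- Illustration (not part of the original file) ---
# Sketch of the size_divisor mode: the padded shape is each side rounded
# up to the next multiple of the divisor, mirroring what
# mmcv.impad_to_multiple computes internally.
import numpy as np

def padded_shape(h, w, size_divisor):
    return (int(np.ceil(h / size_divisor)) * size_divisor,
            int(np.ceil(w / size_divisor)) * size_divisor)

print(padded_shape(800, 1216, 32))  # (800, 1216) -- already aligned
print(padded_shape(800, 1217, 32))  # (800, 1248)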
@PIPELINES.register_module()
class Normalize(object):
    """Normalize the image.

    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB;
            default is True.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Call function to normalize images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Normalized results; the 'img_norm_cfg' key is added into
                the result dict.
        """
        for key in results.get('img_fields', ['img']):
            results[key] = mmcv.imnormalize(results[key], self.mean,
                                            self.std, self.to_rgb)
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
        return repr_str
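
# --- Illustration (not part of the original file) ---
# A typical configuration; the mean/std values below are the ImageNet
# statistics commonly used in mmdetection configs (stated here as an
# assumption -- check your own config for the authoritative values).
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],  # RGB order, hence to_rgb=True
    std=[58.395, 57.12, 57.375],
    to_rgb=True)  # cv2 loads BGR images, so convert before normalizing
normalize_step = dict(type='Normalize', **img_norm_cfg)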
@PIPELINES.register_module()
class RandomCrop(object):
    """Random crop the image & bboxes & masks.

    The absolute `crop_size` is sampled based on `crop_type` and
    `image_size`, then the cropped results are generated.

    Args:
        crop_size (tuple): The relative ratio or absolute pixels of
            height and width.
        crop_type (str, optional): One of "relative_range", "relative",
            "absolute", "absolute_range". "relative" randomly crops
            (h * crop_size[0], w * crop_size[1]) part from an input of
            size (h, w). "relative_range" uniformly samples relative crop
            size from range [crop_size[0], 1] and [crop_size[1], 1] for
            height and width respectively. "absolute" crops from an input
            with absolute size (crop_size[0], crop_size[1]).
            "absolute_range" uniformly samples crop_h in range
            [crop_size[0], min(h, crop_size[1])] and crop_w in range
            [crop_size[0], min(w, crop_size[1])]. Default "absolute".
        allow_negative_crop (bool, optional): Whether to allow a crop that
            does not contain any bbox area. Default False.
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. Defaults to True.

    Note:
        - If the image is smaller than the absolute crop size, return the
          original image.
        - The keys for bboxes, labels and masks must be aligned. That is,
          `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and
          `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and
          `gt_masks_ignore`.
        - If the crop does not contain any gt-bbox region and
          `allow_negative_crop` is set to False, skip this image.
    """

    def __init__(self,
                 crop_size,
                 crop_type='absolute',
                 allow_negative_crop=False,
                 bbox_clip_border=True):
        if crop_type not in [
                'relative_range', 'relative', 'absolute', 'absolute_range'
        ]:
            raise ValueError(f'Invalid crop_type {crop_type}.')
        if crop_type in ['absolute', 'absolute_range']:
            assert crop_size[0] > 0 and crop_size[1] > 0
            assert isinstance(crop_size[0], int) and isinstance(
                crop_size[1], int)
        else:
            assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
        self.crop_size = crop_size
        self.crop_type = crop_type
        self.allow_negative_crop = allow_negative_crop
        self.bbox_clip_border = bbox_clip_border
        # The key correspondence from bboxes to labels and masks.
        self.bbox2label = {
            'gt_bboxes': 'gt_labels',
            'gt_bboxes_ignore': 'gt_labels_ignore'
        }
        self.bbox2mask = {
            'gt_bboxes': 'gt_masks',
            'gt_bboxes_ignore': 'gt_masks_ignore'
        }

    def _crop_data(self, results, crop_size, allow_negative_crop):
        """Function to randomly crop images, bounding boxes, masks and
        semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.
            crop_size (tuple): Expected absolute size after cropping,
                (h, w).
            allow_negative_crop (bool): Whether to allow a crop that does
                not contain any bbox area. Default to False.

        Returns:
            dict: Randomly cropped results; the 'img_shape' key in the
                result dict is updated according to the crop size.
        """
        assert crop_size[0] > 0 and crop_size[1] > 0
        for key in results.get('img_fields', ['img']):
            img = results[key]
            margin_h = max(img.shape[0] - crop_size[0], 0)
            margin_w = max(img.shape[1] - crop_size[1], 0)
            offset_h = np.random.randint(0, margin_h + 1)
            offset_w = np.random.randint(0, margin_w + 1)
            crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
            crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]

            # crop the image
            img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
            img_shape = img.shape
            results[key] = img
        results['img_shape'] = img_shape

        # crop bboxes accordingly and clip to the image boundary
        for key in results.get('bbox_fields', []):
            # e.g. gt_bboxes and gt_bboxes_ignore
            bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
                                   dtype=np.float32)
            bboxes = results[key] - bbox_offset
            if self.bbox_clip_border:
                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
            valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (
                bboxes[:, 3] > bboxes[:, 1])
            # If the crop does not contain any gt-bbox area and
            # allow_negative_crop is False, skip this image.
            if (key == 'gt_bboxes' and not valid_inds.any()
                    and not allow_negative_crop):
                return None
            results[key] = bboxes[valid_inds, :]
            # label fields, e.g. gt_labels and gt_labels_ignore
            label_key = self.bbox2label.get(key)
            if label_key in results:
                results[label_key] = results[label_key][valid_inds]

            # mask fields, e.g. gt_masks and gt_masks_ignore
            mask_key = self.bbox2mask.get(key)
            if mask_key in results:
                results[mask_key] = results[mask_key][
                    valid_inds.nonzero()[0]].crop(
                        np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))

        # crop semantic seg
        for key in results.get('seg_fields', []):
            results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]

        return results

    def _get_crop_size(self, image_size):
        """Randomly generates the absolute crop size based on `crop_type`
        and `image_size`.

        Args:
            image_size (tuple): (h, w).

        Returns:
            crop_size (tuple): (crop_h, crop_w) in absolute pixels.
        """
        h, w = image_size
        if self.crop_type == 'absolute':
            return (min(self.crop_size[0], h), min(self.crop_size[1], w))
        elif self.crop_type == 'absolute_range':
            assert self.crop_size[0] <= self.crop_size[1]
            crop_h = np.random.randint(
                min(h, self.crop_size[0]),
                min(h, self.crop_size[1]) + 1)
            crop_w = np.random.randint(
                min(w, self.crop_size[0]),
                min(w, self.crop_size[1]) + 1)
            return crop_h, crop_w
        elif self.crop_type == 'relative':
            crop_h, crop_w = self.crop_size
            return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
        elif self.crop_type == 'relative_range':
            crop_size = np.asarray(self.crop_size, dtype=np.float32)
            crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
            return int(h * crop_h + 0.5), int(w * crop_w + 0.5)

    def __call__(self, results):
        """Call function to randomly crop images, bounding boxes, masks
        and semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results; the 'img_shape' key in the
                result dict is updated according to the crop size.
        """
        image_size = results['img'].shape[:2]
        crop_size = self._get_crop_size(image_size)
        results = self._crop_data(results, crop_size,
                                  self.allow_negative_crop)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(crop_size={self.crop_size}, '
        repr_str += f'crop_type={self.crop_type}, '
        repr_str += f'allow_negative_crop={self.allow_negative_crop}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
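
# --- Illustration (not part of the original file) ---
# Standalone sketch of how _get_crop_size maps each crop_type to an
# absolute (crop_h, crop_w), assuming a 480x640 input image.
import numpy as np

h, w = 480, 640
# 'absolute': clamp the requested size to the image
print(min(300, h), min(300, w))                # 300 300
# 'relative': scale each side, rounding to the nearest pixel
print(int(h * 0.6 + 0.5), int(w * 0.6 + 0.5))  # 288 384
# 'absolute_range': sample each side within [lo, min(side, hi)]
lo, hi = 256, 512
crop_h = np.random.randint(min(h, lo), min(h, hi) + 1)  # in [256, 480]
print(crop_h)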
@PIPELINES.register_module()
class SegRescale(object):
    """Rescale semantic segmentation maps.

    Args:
        scale_factor (float): The scale factor of the final output.
        backend (str): Image rescale backend, choices are 'cv2' and
            'pillow'. These two backends generate slightly different
            results. Defaults to 'cv2'.
    """

    def __init__(self, scale_factor=1, backend='cv2'):
        self.scale_factor = scale_factor
        self.backend = backend

    def __call__(self, results):
        """Call function to scale the semantic segmentation map.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with the semantic segmentation map scaled.
        """
        for key in results.get('seg_fields', []):
            if self.scale_factor != 1:
                results[key] = mmcv.imrescale(
                    results[key],
                    self.scale_factor,
                    interpolation='nearest',
                    backend=self.backend)
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
@PIPELINES.register_module()
class PhotoMetricDistortion(object):
    """Apply photometric distortion to an image sequentially; every
    transformation is applied with a probability of 0.5. Random contrast
    is applied either second or second to last:

    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, results):
        """Call function to perform photometric distortion on images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images distorted.
        """
        if 'img_fields' in results:
            assert results['img_fields'] == ['img'], \
                'Only single img_fields is allowed'
        img = results['img']
        assert img.dtype == np.float32, \
            'PhotoMetricDistortion needs the input image of dtype np.float32,' \
            ' please set "to_float32=True" in "LoadImageFromFile" pipeline'
        # random brightness
        if random.randint(2):
            delta = random.uniform(-self.brightness_delta,
                                   self.brightness_delta)
            img += delta

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)

        # random saturation
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower,
                                          self.saturation_upper)

        # random hue
        if random.randint(2):
            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV to BGR
        img = mmcv.hsv2bgr(img)

        # random contrast
        if mode == 0:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # randomly swap channels
        if random.randint(2):
            img = img[..., random.permutation(3)]

        results['img'] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
        repr_str += 'contrast_range='
        repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
        repr_str += 'saturation_range='
        repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
        repr_str += f'hue_delta={self.hue_delta})'
        return repr_str
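
# --- Illustration (not part of the original file) ---
# Sketch of a typical configuration; the values shown are the defaults
# above. Remember the float32 requirement: set to_float32=True in
# LoadImageFromFile, as the assertion in __call__ enforces.
photo_distortion = dict(
    type='PhotoMetricDistortion',
    brightness_delta=32,
    contrast_range=(0.5, 1.5),
    saturation_range=(0.5, 1.5),
    hue_delta=18)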
@PIPELINES.register_module()
class Expand(object):
    """Randomly expand the image & bboxes.

    Randomly place the original image on a canvas of 'ratio' x original
    image size filled with mean values. The ratio is in the range of
    ratio_range.

    Args:
        mean (tuple): mean value of the dataset.
        to_rgb (bool): whether the order of mean needs to be converted to
            align with RGB.
        ratio_range (tuple): range of expand ratio.
        seg_ignore_label (int, optional): label used to fill the expanded
            area of the semantic segmentation map.
        prob (float): probability of applying this transformation.
    """

    def __init__(self,
                 mean=(0, 0, 0),
                 to_rgb=True,
                 ratio_range=(1, 4),
                 seg_ignore_label=None,
                 prob=0.5):
        self.to_rgb = to_rgb
        self.ratio_range = ratio_range
        if to_rgb:
            self.mean = mean[::-1]
        else:
            self.mean = mean
        self.min_ratio, self.max_ratio = ratio_range
        self.seg_ignore_label = seg_ignore_label
        self.prob = prob

    def __call__(self, results):
        """Call function to expand images and bounding boxes.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images and bounding boxes expanded.
        """
        if random.uniform(0, 1) > self.prob:
            return results

        if 'img_fields' in results:
            assert results['img_fields'] == ['img'], \
                'Only single img_fields is allowed'
        img = results['img']

        h, w, c = img.shape
        ratio = random.uniform(self.min_ratio, self.max_ratio)
        # speedup expand when meets large image
        if np.all(self.mean == self.mean[0]):
            expand_img = np.empty((int(h * ratio), int(w * ratio), c),
                                  img.dtype)
            expand_img.fill(self.mean[0])
        else:
            expand_img = np.full((int(h * ratio), int(w * ratio), c),
                                 self.mean,
                                 dtype=img.dtype)
        left = int(random.uniform(0, w * ratio - w))
        top = int(random.uniform(0, h * ratio - h))
        expand_img[top:top + h, left:left + w] = img

        results['img'] = expand_img

        # expand bboxes
        for key in results.get('bbox_fields', []):
            results[key] = results[key] + np.tile(
                (left, top), 2).astype(results[key].dtype)

        # expand masks
        for key in results.get('mask_fields', []):
            results[key] = results[key].expand(
                int(h * ratio), int(w * ratio), top, left)

        # expand segs
        for key in results.get('seg_fields', []):
            gt_seg = results[key]
            expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
                                    self.seg_ignore_label,
                                    dtype=gt_seg.dtype)
            expand_gt_seg[top:top + h, left:left + w] = gt_seg
            results[key] = expand_gt_seg
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(mean={self.mean}, to_rgb={self.to_rgb}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'seg_ignore_label={self.seg_ignore_label})'
        return repr_str
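
# --- Illustration (not part of the original file) ---
# Numeric sketch of Expand: with ratio=2.0, a 100x200 image lands somewhere
# inside a 200x400 mean-filled canvas, and every bbox shifts by (left, top).
import numpy as np

h, w, ratio = 100, 200, 2.0
left, top = 50, 25  # one possible draw from uniform(0, w*ratio - w), etc.
bboxes = np.array([[10., 10., 60., 40.]])
expanded = bboxes + np.tile((left, top), 2).astype(bboxes.dtype)
print(expanded)  # [[ 60.  35. 110.  65.]]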
@PIPELINES.register_module()
class MinIoURandomCrop(object):
    """Random crop the image & bboxes; the cropped patches have a minimum
    IoU requirement with the original image & bboxes, and the IoU
    threshold is randomly selected from min_ious.

    Args:
        min_ious (tuple): minimum IoU threshold for all intersections
            with bounding boxes.
        min_crop_size (float): minimum crop size (i.e. h,w := a*h, a*w,
            where a >= min_crop_size).
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. Defaults to True.

    Note:
        The keys for bboxes, labels and masks should be paired. That is,
        `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and
        `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
    """

    def __init__(self,
                 min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
                 min_crop_size=0.3,
                 bbox_clip_border=True):
        # 1: return ori img
        self.min_ious = min_ious
        self.sample_mode = (1, *min_ious, 0)
        self.min_crop_size = min_crop_size
        self.bbox_clip_border = bbox_clip_border
        self.bbox2label = {
            'gt_bboxes': 'gt_labels',
            'gt_bboxes_ignore': 'gt_labels_ignore'
        }
        self.bbox2mask = {
            'gt_bboxes': 'gt_masks',
            'gt_bboxes_ignore': 'gt_masks_ignore'
        }

    def __call__(self, results):
        """Call function to crop images and bounding boxes with minimum
        IoU constraint.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images and bounding boxes cropped;
                the 'img_shape' key is updated.
        """
        if 'img_fields' in results:
            assert results['img_fields'] == ['img'], \
                'Only single img_fields is allowed'
        img = results['img']
        assert 'bbox_fields' in results
        boxes = [results[key] for key in results['bbox_fields']]
        boxes = np.concatenate(boxes, 0)
        h, w, c = img.shape
        while True:
            mode = random.choice(self.sample_mode)
            self.mode = mode
            if mode == 1:
                return results

            min_iou = mode
            for i in range(50):
                new_w = random.uniform(self.min_crop_size * w, w)
                new_h = random.uniform(self.min_crop_size * h, h)

                # h / w in [0.5, 2]
                if new_h / new_w < 0.5 or new_h / new_w > 2:
                    continue

                left = random.uniform(w - new_w)
                top = random.uniform(h - new_h)

                patch = np.array(
                    (int(left), int(top), int(left + new_w),
                     int(top + new_h)))
                # Line or point crop is not allowed
                if patch[2] == patch[0] or patch[3] == patch[1]:
                    continue
                overlaps = bbox_overlaps(
                    patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
                if len(overlaps) > 0 and overlaps.min() < min_iou:
                    continue

                # center of boxes should be inside the cropped img
                # only adjust boxes and instance masks when the gt is not
                # empty
                if len(overlaps) > 0:
                    # adjust boxes
                    def is_center_of_bboxes_in_patch(boxes, patch):
                        center = (boxes[:, :2] + boxes[:, 2:]) / 2
                        mask = ((center[:, 0] > patch[0]) *
                                (center[:, 1] > patch[1]) *
                                (center[:, 0] < patch[2]) *
                                (center[:, 1] < patch[3]))
                        return mask

                    mask = is_center_of_bboxes_in_patch(boxes, patch)
                    if not mask.any():
                        continue
                    for key in results.get('bbox_fields', []):
                        boxes = results[key].copy()
                        mask = is_center_of_bboxes_in_patch(boxes, patch)
                        boxes = boxes[mask]
                        if self.bbox_clip_border:
                            boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
                            boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
                        boxes -= np.tile(patch[:2], 2)

                        results[key] = boxes
                        # labels
                        label_key = self.bbox2label.get(key)
                        if label_key in results:
                            results[label_key] = results[label_key][mask]

                        # mask fields
                        mask_key = self.bbox2mask.get(key)
                        if mask_key in results:
                            results[mask_key] = results[mask_key][
                                mask.nonzero()[0]].crop(patch)
                # adjust the img no matter whether the gt is empty before
                # crop
                img = img[patch[1]:patch[3], patch[0]:patch[2]]
                results['img'] = img
                results['img_shape'] = img.shape

                # seg fields
                for key in results.get('seg_fields', []):
                    results[key] = results[key][patch[1]:patch[3],
                                                patch[0]:patch[2]]
                return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(min_ious={self.min_ious}, '
        repr_str += f'min_crop_size={self.min_crop_size}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
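
# --- Illustration (not part of the original file) ---
# Sketch of the center-in-patch test used above: a box survives the crop
# only if its center lies strictly inside the sampled patch.
import numpy as np

patch = np.array([20, 20, 80, 80])        # (x1, y1, x2, y2)
boxes = np.array([[0., 0., 30., 30.],     # center (15, 15) -> dropped
                  [30., 30., 70., 70.]])  # center (50, 50) -> kept
center = (boxes[:, :2] + boxes[:, 2:]) / 2
mask = ((center[:, 0] > patch[0]) & (center[:, 1] > patch[1]) &
        (center[:, 0] < patch[2]) & (center[:, 1] < patch[3]))
print(mask)  # [False  True]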
@PIPELINES.register_module()
class Corrupt(object):
    """Corruption augmentation.

    Corruption transforms implemented based on
    `imagecorruptions <https://github.com/bethgelab/imagecorruptions>`_.

    Args:
        corruption (str): Corruption name.
        severity (int, optional): The severity of corruption. Default: 1.
    """

    def __init__(self, corruption, severity=1):
        self.corruption = corruption
        self.severity = severity

    def __call__(self, results):
        """Call function to corrupt image.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images corrupted.
        """
        if corrupt is None:
            raise RuntimeError('imagecorruptions is not installed')
        if 'img_fields' in results:
            assert results['img_fields'] == ['img'], \
                'Only single img_fields is allowed'
        results['img'] = corrupt(
            results['img'].astype(np.uint8),
            corruption_name=self.corruption,
            severity=self.severity)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(corruption={self.corruption}, '
        repr_str += f'severity={self.severity})'
        return repr_str
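
# --- Illustration (not part of the original file) ---
# Sketch of a typical usage; 'gaussian_noise' is one of the corruption
# names shipped with the imagecorruptions package referenced above.
corrupt_step = dict(type='Corrupt', corruption='gaussian_noise', severity=3)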
@PIPELINES.register_module()
class Albu(object):
    """Albumentation augmentation.

    Adds custom transformations from the Albumentations library.
    Please visit `https://albumentations.readthedocs.io`
    for more information.

    An example of ``transforms`` is as follows:

    .. code-block::

        [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]

    Args:
        transforms (list[dict]): A list of albu transformations.
        bbox_params (dict): Bbox_params for albumentations `Compose`.
        keymap (dict): Contains {'input key': 'albumentation-style key'}.
        update_pad_shape (bool): Whether to update 'pad_shape' with the
            augmented image shape.
        skip_img_without_anno (bool): Whether to skip the image if no
            annotations are left after augmentation.
    """

    def __init__(self,
                 transforms,
                 bbox_params=None,
                 keymap=None,
                 update_pad_shape=False,
                 skip_img_without_anno=False):
        if Compose is None:
            raise RuntimeError('albumentations is not installed')

        self.transforms = transforms
        self.filter_lost_elements = False
        self.update_pad_shape = update_pad_shape
        self.skip_img_without_anno = skip_img_without_anno

        # A simple workaround to remove masks without boxes
        if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
                and 'filter_lost_elements' in bbox_params):
            self.filter_lost_elements = True
            self.origin_label_fields = bbox_params['label_fields']
            bbox_params['label_fields'] = ['idx_mapper']
            del bbox_params['filter_lost_elements']

        self.bbox_params = (
            self.albu_builder(bbox_params) if bbox_params else None)
        self.aug = Compose([self.albu_builder(t) for t in self.transforms],
                           bbox_params=self.bbox_params)

        if not keymap:
            self.keymap_to_albu = {
                'img': 'image',
                'gt_masks': 'masks',
                'gt_bboxes': 'bboxes'
            }
        else:
            self.keymap_to_albu = keymap
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    def albu_builder(self, cfg):
        """Import a module from albumentations.

        It inherits some of the :func:`build_from_cfg` logic.

        Args:
            cfg (dict): Config dict. It should at least contain the key
                "type".

        Returns:
            obj: The constructed object.
        """
        assert isinstance(cfg, dict) and 'type' in cfg
        args = cfg.copy()

        obj_type = args.pop('type')
        if mmcv.is_str(obj_type):
            if albumentations is None:
                raise RuntimeError('albumentations is not installed')
            obj_cls = getattr(albumentations, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]

        return obj_cls(**args)

    @staticmethod
    def mapper(d, keymap):
        """Dictionary mapper. Renames keys according to the keymap
        provided.

        Args:
            d (dict): old dict
            keymap (dict): {'old_key': 'new_key'}

        Returns:
            dict: new dict.
        """
        updated_dict = {}
        for k, v in zip(d.keys(), d.values()):
            new_k = keymap.get(k, k)
            updated_dict[new_k] = d[k]
        return updated_dict

    def __call__(self, results):
        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)
        # TODO: add bbox_fields
        if 'bboxes' in results:
            # to list of boxes
            if isinstance(results['bboxes'], np.ndarray):
                results['bboxes'] = [x for x in results['bboxes']]
            # add pseudo-field for filtration
            if self.filter_lost_elements:
                results['idx_mapper'] = np.arange(len(results['bboxes']))

        # TODO: Support mask structure in albu
        if 'masks' in results:
            if isinstance(results['masks'], PolygonMasks):
                raise NotImplementedError(
                    'Albu only supports BitMap masks now')
            ori_masks = results['masks']
            if albumentations.__version__ < '0.5':
                results['masks'] = results['masks'].masks
            else:
                results['masks'] = [mask for mask in results['masks'].masks]

        results = self.aug(**results)

        if 'bboxes' in results:
            if isinstance(results['bboxes'], list):
                results['bboxes'] = np.array(
                    results['bboxes'], dtype=np.float32)
            results['bboxes'] = results['bboxes'].reshape(-1, 4)

            # filter label_fields
            if self.filter_lost_elements:

                for label in self.origin_label_fields:
                    results[label] = np.array(
                        [results[label][i] for i in results['idx_mapper']])
                if 'masks' in results:
                    results['masks'] = np.array(
                        [results['masks'][i] for i in results['idx_mapper']])
                    results['masks'] = ori_masks.__class__(
                        results['masks'], results['image'].shape[0],
                        results['image'].shape[1])

                if (not len(results['idx_mapper'])
                        and self.skip_img_without_anno):
                    return None

        if 'gt_labels' in results:
            if isinstance(results['gt_labels'], list):
                results['gt_labels'] = np.array(results['gt_labels'])
            results['gt_labels'] = results['gt_labels'].astype(np.int64)

        # back to the original format
        results = self.mapper(results, self.keymap_back)

        # update final shape
        if self.update_pad_shape:
            results['pad_shape'] = results['img'].shape

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
        return repr_str
@PIPELINES.register_module()
class RandomCenterCropPad(object):
    """Random center crop and random around padding for CornerNet.

    This operation generates a randomly cropped image from the original
    image and pads it simultaneously. Different from :class:`RandomCrop`,
    the output shape may not equal ``crop_size`` strictly. We choose a
    random value from ``ratios`` and the output shape could be larger or
    smaller than ``crop_size``. The padding operation is also different
    from :class:`Pad`: here we use around padding instead of right-bottom
    padding.

    The relation between the output image (padding image) and the
    original image:

    .. code:: text

                        output image
               +----------------------------+
               |          padded area       |
        +------|----------------------------|----------+
        |      |         cropped area       |          |
        |      |         +---------------+  |          |
        |      |         |    .   center |  |          | original image
        |      |         |        range  |  |          |
        |      |         +---------------+  |          |
        +------|----------------------------|----------+
               |          padded area       |
               +----------------------------+

    There are 5 main areas in the figure:

    - output image: output image of this operation, also called the
      padding image in the following instructions.
    - original image: input image of this operation.
    - padded area: non-intersecting area of the output image and the
      original image.
    - cropped area: the overlap of the output image and the original
      image.
    - center range: a smaller area where the random center is chosen
      from. The center range is computed from ``border`` and the original
      image's shape to avoid the random center being too close to the
      original image's border.

    This operation also acts differently in train and test mode; the
    summary pipelines are listed below.

    Train pipeline:

    1. Choose a ``random_ratio`` from ``ratios``; the shape of the
       padding image will be ``random_ratio * crop_size``.
    2. Choose a ``random_center`` in the center range.
    3. Generate the padding image with its center matching the
       ``random_center``.
    4. Initialize the padding image with pixel values equal to ``mean``.
    5. Copy the cropped area to the padding image.
    6. Refine annotations.

    Test pipeline:

    1. Compute the output shape according to ``test_pad_mode``.
    2. Generate the padding image with its center matching the original
       image center.
    3. Initialize the padding image with pixel values equal to ``mean``.
    4. Copy the ``cropped area`` to the padding image.

    Args:
        crop_size (tuple | None): expected size after crop; the final
            size will be computed according to the ratio. Requires (h, w)
            in train mode, and None in test mode.
        ratios (tuple): randomly select a ratio from the tuple and crop
            the image to (crop_size[0] * ratio) * (crop_size[1] * ratio).
            Only available in train mode.
        border (int): max distance from the center select area to the
            image border. Only available in train mode.
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB.
        test_mode (bool): whether to involve random variables in the
            transform. In train mode, crop_size is fixed, and the center
            coords and ratio are randomly selected from predefined lists.
            In test mode, crop_size is the image's original shape, and
            the center coords and ratio are fixed.
        test_pad_mode (tuple): padding method and padding shape value,
            only available in test mode. Default is using 'logical_or'
            with 127 as the padding shape value.

            - 'logical_or': final_shape = input_shape | padding_shape_value
            - 'size_divisor': final_shape = int(
              ceil(input_shape / padding_shape_value) * padding_shape_value)
        bbox_clip_border (bool, optional): Whether to clip the objects
            outside the border of the image. Defaults to True.
    """

    def __init__(self,
                 crop_size=None,
                 ratios=(0.9, 1.0, 1.1),
                 border=128,
                 mean=None,
                 std=None,
                 to_rgb=None,
                 test_mode=False,
                 test_pad_mode=('logical_or', 127),
                 bbox_clip_border=True):
        if test_mode:
            assert crop_size is None, 'crop_size must be None in test mode'
            assert ratios is None, 'ratios must be None in test mode'
            assert border is None, 'border must be None in test mode'
            assert isinstance(test_pad_mode, (list, tuple))
            assert test_pad_mode[0] in ['logical_or', 'size_divisor']
        else:
            assert isinstance(crop_size, (list, tuple))
            assert crop_size[0] > 0 and crop_size[1] > 0, (
                'crop_size must > 0 in train mode')
            assert isinstance(ratios, (list, tuple))
            assert test_pad_mode is None, (
                'test_pad_mode must be None in train mode')

        self.crop_size = crop_size
        self.ratios = ratios
        self.border = border
        # We do not set default values for mean, std and to_rgb because
        # these hyper-parameters are easy to forget but could affect the
        # performance. Please use the same setting as Normalize for
        # performance assurance.
        assert mean is not None and std is not None and to_rgb is not None
        self.to_rgb = to_rgb
        self.input_mean = mean
        self.input_std = std
        if to_rgb:
            self.mean = mean[::-1]
            self.std = std[::-1]
        else:
            self.mean = mean
            self.std = std
        self.test_mode = test_mode
        self.test_pad_mode = test_pad_mode
        self.bbox_clip_border = bbox_clip_border

    def _get_border(self, border, size):
        """Get the final border for the target size.

        This function generates a ``final_border`` according to the
        image's shape. The area between ``final_border`` and
        ``size - final_border`` is the ``center range``. We randomly
        choose the center from the ``center range`` to avoid the random
        center being too close to the original image's border. The
        ``center range`` should also be larger than 0.

        Args:
            border (int): The initial border; default is 128.
            size (int): The width or height of the original image.

        Returns:
            int: The final border.
        """
        k = 2 * border / size
        i = pow(2, np.ceil(np.log2(np.ceil(k))) + (k == int(k)))
        return border // i

    def _filter_boxes(self, patch, boxes):
        """Check whether the center of each box is in the patch.

        Args:
            patch (list[int]): The cropped area, [left, top, right,
                bottom].
            boxes (numpy array, (N x 4)): Ground truth boxes.

        Returns:
            mask (numpy array, (N,)): Whether each box is inside the
                patch.
        """
        center = (boxes[:, :2] + boxes[:, 2:]) / 2
        mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (
            center[:, 0] < patch[2]) * (
                center[:, 1] < patch[3])
        return mask

    def _crop_image_and_paste(self, image, center, size):
        """Crop the image with a given center and size, then paste the
        cropped image to a blank image with the two centers aligned.

        This function is equivalent to generating a blank image with
        ``size`` as its shape, then covering it on the original image
        with the two centers (the center of the blank image and the
        random center of the original image) aligned. The overlap area is
        pasted from the original image and the outside area is filled
        with the ``mean pixel``.

        Args:
            image (np array, H x W x C): Original image.
            center (list[int]): Target crop center coord.
            size (list[int]): Target crop size, [target_h, target_w].

        Returns:
            cropped_img (np array, target_h x target_w x C): Cropped
                image.
            border (np array, 4): The distances of the four borders of
                ``cropped_img`` to the original image area, [top, bottom,
                left, right].
            patch (list[int]): The cropped area, [left, top, right,
                bottom].
        """
        center_y, center_x = center
        target_h, target_w = size
        img_h, img_w, img_c = image.shape

        x0 = max(0, center_x - target_w // 2)
        x1 = min(center_x + target_w // 2, img_w)
        y0 = max(0, center_y - target_h // 2)
        y1 = min(center_y + target_h // 2, img_h)
        patch = np.array((int(x0), int(y0), int(x1), int(y1)))

        left, right = center_x - x0, x1 - center_x
        top, bottom = center_y - y0, y1 - center_y

        cropped_center_y, cropped_center_x = target_h // 2, target_w // 2
        cropped_img = np.zeros((target_h, target_w, img_c), dtype=image.dtype)
        for i in range(img_c):
            cropped_img[:, :, i] += self.mean[i]
        y_slice = slice(cropped_center_y - top, cropped_center_y + bottom)
        x_slice = slice(cropped_center_x - left, cropped_center_x + right)
        cropped_img[y_slice, x_slice, :] = image[y0:y1, x0:x1, :]

        border = np.array([
            cropped_center_y - top, cropped_center_y + bottom,
            cropped_center_x - left, cropped_center_x + right
        ],
                          dtype=np.float32)

        return cropped_img, border, patch

    def _train_aug(self, results):
        """Random crop and around padding of the original image.

        Args:
            results (dict): Image information in the augment pipeline.

        Returns:
            results (dict): The updated dict.
        """
        img = results['img']
        h, w, c = img.shape
        boxes = results['gt_bboxes']
        while True:
            scale = random.choice(self.ratios)
            new_h = int(self.crop_size[0] * scale)
            new_w = int(self.crop_size[1] * scale)
            h_border = self._get_border(self.border, h)
            w_border = self._get_border(self.border, w)

            for i in range(50):
                center_x = random.randint(low=w_border, high=w - w_border)
                center_y = random.randint(low=h_border, high=h - h_border)

                cropped_img, border, patch = self._crop_image_and_paste(
                    img, [center_y, center_x], [new_h, new_w])

                mask = self._filter_boxes(patch, boxes)
                # if the image does not have a valid bbox, any crop patch
                # is valid.
                if not mask.any() and len(boxes) > 0:
                    continue

                results['img'] = cropped_img
                results['img_shape'] = cropped_img.shape
                results['pad_shape'] = cropped_img.shape

                x0, y0, x1, y1 = patch

                left_w, top_h = center_x - x0, center_y - y0
                cropped_center_x, cropped_center_y = new_w // 2, new_h // 2

                # crop bboxes accordingly and clip to the image boundary
                for key in results.get('bbox_fields', []):
                    mask = self._filter_boxes(patch, results[key])
                    bboxes = results[key][mask]
                    bboxes[:, 0:4:2] += cropped_center_x - left_w - x0
                    bboxes[:, 1:4:2] += cropped_center_y - top_h - y0
                    if self.bbox_clip_border:
                        bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w)
                        bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h)
                    keep = (bboxes[:, 2] > bboxes[:, 0]) & (
                        bboxes[:, 3] > bboxes[:, 1])
                    bboxes = bboxes[keep]
                    results[key] = bboxes
                    if key in ['gt_bboxes']:
                        if 'gt_labels' in results:
                            labels = results['gt_labels'][mask]
                            labels = labels[keep]
                            results['gt_labels'] = labels
                        if 'gt_masks' in results:
                            raise NotImplementedError(
                                'RandomCenterCropPad only supports bbox.')

                # crop semantic seg
                for key in results.get('seg_fields', []):
                    raise NotImplementedError(
                        'RandomCenterCropPad only supports bbox.')
                return results

    def _test_aug(self, results):
        """Around padding the original image without cropping.

        The padding mode and value are from ``test_pad_mode``.

        Args:
            results (dict): Image information in the augment pipeline.

        Returns:
            results (dict): The updated dict.
        """
        img = results['img']
        h, w, c = img.shape
        results['img_shape'] = img.shape
        if self.test_pad_mode[0] in ['logical_or']:
            target_h = h | self.test_pad_mode[1]
            target_w = w | self.test_pad_mode[1]
        elif self.test_pad_mode[0] in ['size_divisor']:
            divisor = self.test_pad_mode[1]
            target_h = int(np.ceil(h / divisor)) * divisor
            target_w = int(np.ceil(w / divisor)) * divisor
        else:
            raise NotImplementedError(
                'RandomCenterCropPad only supports two testing pad modes: '
                'logical_or and size_divisor.')

        cropped_img, border, _ = self._crop_image_and_paste(
            img, [h // 2, w // 2], [target_h, target_w])
        results['img'] = cropped_img
        results['pad_shape'] = cropped_img.shape
        results['border'] = border
        return results

    def __call__(self, results):
        img = results['img']
        assert img.dtype == np.float32, (
            'RandomCenterCropPad needs the input image of dtype np.float32,'
            ' please set "to_float32=True" in "LoadImageFromFile" pipeline')
        h, w, c = img.shape
        assert c == len(self.mean)
        if self.test_mode:
            return self._test_aug(results)
        else:
            return self._train_aug(results)

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(crop_size={self.crop_size}, '
        repr_str += f'ratios={self.ratios}, '
        repr_str += f'border={self.border}, '
        repr_str += f'mean={self.input_mean}, '
        repr_str += f'std={self.input_std}, '
        repr_str += f'to_rgb={self.to_rgb}, '
        repr_str += f'test_mode={self.test_mode}, '
        repr_str += f'test_pad_mode={self.test_pad_mode}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
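
# --- Illustration (not part of the original file) ---
# Numeric sketch of the two test-time padding modes described above.
# 'logical_or' with 127 rounds each side up to the next value of the form
# k * 128 - 1; 'size_divisor' rounds up to a multiple of the divisor.
import numpy as np

h, w = 500, 640
print(h | 127, w | 127)  # 511 767
divisor = 32
print(int(np.ceil(h / divisor)) * divisor,
      int(np.ceil(w / divisor)) * divisor)  # 512 640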
@PIPELINES.register_module()
class CutOut(object):
    """CutOut operation.

    Randomly drop some regions of the image, as used in
    `Cutout <https://arxiv.org/abs/1708.04552>`_.

    Args:
        n_holes (int | tuple[int, int]): Number of regions to be dropped.
            If it is given as a tuple, the number of holes will be
            randomly selected from the closed interval
            [`n_holes[0]`, `n_holes[1]`].
        cutout_shape (tuple[int, int] | list[tuple[int, int]]): The
            candidate shape of dropped regions. It can be
            `tuple[int, int]` to use a fixed cutout shape, or
            `list[tuple[int, int]]` to randomly choose a shape from the
            list.
        cutout_ratio (tuple[float, float] | list[tuple[float, float]]):
            The candidate ratio of dropped regions. It can be
            `tuple[float, float]` to use a fixed ratio or
            `list[tuple[float, float]]` to randomly choose a ratio from
            the list. Please note that `cutout_shape` and `cutout_ratio`
            cannot both be given at the same time.
        fill_in (tuple[float, float, float] | tuple[int, int, int]): The
            value of the pixels to fill in the dropped regions.
            Default: (0, 0, 0).
    """

    def __init__(self,
                 n_holes,
                 cutout_shape=None,
                 cutout_ratio=None,
                 fill_in=(0, 0, 0)):

        assert (cutout_shape is None) ^ (cutout_ratio is None), \
            'Either cutout_shape or cutout_ratio should be specified.'
        assert (isinstance(cutout_shape, (list, tuple))
                or isinstance(cutout_ratio, (list, tuple)))
        if isinstance(n_holes, tuple):
            assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
        else:
            n_holes = (n_holes, n_holes)
        self.n_holes = n_holes
        self.fill_in = fill_in
        self.with_ratio = cutout_ratio is not None
        self.candidates = cutout_ratio if self.with_ratio else cutout_shape
        if not isinstance(self.candidates, list):
            self.candidates = [self.candidates]

    def __call__(self, results):
        """Call function to drop some regions of the image."""
        h, w, c = results['img'].shape
        n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
        for _ in range(n_holes):
            x1 = np.random.randint(0, w)
            y1 = np.random.randint(0, h)
            index = np.random.randint(0, len(self.candidates))
            if not self.with_ratio:
                cutout_w, cutout_h = self.candidates[index]
            else:
                cutout_w = int(self.candidates[index][0] * w)
                cutout_h = int(self.candidates[index][1] * h)

            x2 = np.clip(x1 + cutout_w, 0, w)
            y2 = np.clip(y1 + cutout_h, 0, h)
            results['img'][y1:y2, x1:x2, :] = self.fill_in

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(n_holes={self.n_holes}, '
        repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
                     else f'cutout_shape={self.candidates}, ')
        repr_str += f'fill_in={self.fill_in})'
        return repr_str
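
# --- Illustration (not part of the original file) ---
# Two alternative configurations (only one of cutout_shape/cutout_ratio
# may be given, per the assertion in __init__).
cutout_fixed = dict(type='CutOut', n_holes=(1, 3), cutout_shape=(32, 32))
cutout_relative = dict(type='CutOut', n_holes=2, cutout_ratio=(0.1, 0.2))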
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/samplers/__init__.py (new file)
from .distributed_sampler import DistributedSampler
from .group_sampler import DistributedGroupSampler, GroupSampler

__all__ = ['DistributedSampler', 'DistributedGroupSampler', 'GroupSampler']
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/samplers/distributed_sampler.py (new file)
import math

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler


class DistributedSampler(_DistributedSampler):

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        # in case that indices is shorter than half of total_size
        indices = (indices *
                   math.ceil(self.total_size / len(indices)))[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
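
# --- Illustration (not part of the original file) ---
# Sketch of the padding/subsampling steps above: indices are tiled until
# they cover total_size, then each rank takes a strided slice.
import math

indices = [3, 1, 2, 0]
total_size, num_replicas, rank = 6, 2, 0
indices = (indices * math.ceil(total_size / len(indices)))[:total_size]
print(indices)                                # [3, 1, 2, 0, 3, 1]
print(indices[rank:total_size:num_replicas])  # rank 0 gets [3, 2, 3]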
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/samplers/group_sampler.py (new file)
from
__future__
import
division
import
math
import
numpy
as
np
import
torch
from
mmcv.runner
import
get_dist_info
from
torch.utils.data
import
Sampler
class
GroupSampler
(
Sampler
):
def
__init__
(
self
,
dataset
,
samples_per_gpu
=
1
):
assert
hasattr
(
dataset
,
'flag'
)
self
.
dataset
=
dataset
self
.
samples_per_gpu
=
samples_per_gpu
self
.
flag
=
dataset
.
flag
.
astype
(
np
.
int64
)
self
.
group_sizes
=
np
.
bincount
(
self
.
flag
)
self
.
num_samples
=
0
for
i
,
size
in
enumerate
(
self
.
group_sizes
):
self
.
num_samples
+=
int
(
np
.
ceil
(
size
/
self
.
samples_per_gpu
))
*
self
.
samples_per_gpu
def
__iter__
(
self
):
indices
=
[]
for
i
,
size
in
enumerate
(
self
.
group_sizes
):
if
size
==
0
:
continue
indice
=
np
.
where
(
self
.
flag
==
i
)[
0
]
assert
len
(
indice
)
==
size
np
.
random
.
shuffle
(
indice
)
num_extra
=
int
(
np
.
ceil
(
size
/
self
.
samples_per_gpu
)
)
*
self
.
samples_per_gpu
-
len
(
indice
)
indice
=
np
.
concatenate
(
[
indice
,
np
.
random
.
choice
(
indice
,
num_extra
)])
indices
.
append
(
indice
)
indices
=
np
.
concatenate
(
indices
)
indices
=
[
indices
[
i
*
self
.
samples_per_gpu
:(
i
+
1
)
*
self
.
samples_per_gpu
]
for
i
in
np
.
random
.
permutation
(
range
(
len
(
indices
)
//
self
.
samples_per_gpu
))
]
indices
=
np
.
concatenate
(
indices
)
indices
=
indices
.
astype
(
np
.
int64
).
tolist
()
assert
len
(
indices
)
==
self
.
num_samples
return
iter
(
indices
)
def
__len__
(
self
):
return
self
.
num_samples
class
DistributedGroupSampler
(
Sampler
):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
"""
def
__init__
(
self
,
dataset
,
samples_per_gpu
=
1
,
num_replicas
=
None
,
rank
=
None
):
_rank
,
_num_replicas
=
get_dist_info
()
if
num_replicas
is
None
:
num_replicas
=
_num_replicas
if
rank
is
None
:
rank
=
_rank
self
.
dataset
=
dataset
self
.
samples_per_gpu
=
samples_per_gpu
self
.
num_replicas
=
num_replicas
self
.
rank
=
rank
self
.
epoch
=
0
assert
hasattr
(
self
.
dataset
,
'flag'
)
self
.
flag
=
self
.
dataset
.
flag
self
.
group_sizes
=
np
.
bincount
(
self
.
flag
)
self
.
num_samples
=
0
for
i
,
j
in
enumerate
(
self
.
group_sizes
):
self
.
num_samples
+=
int
(
math
.
ceil
(
self
.
group_sizes
[
i
]
*
1.0
/
self
.
samples_per_gpu
/
self
.
num_replicas
))
*
self
.
samples_per_gpu
self
.
total_size
=
self
.
num_samples
*
self
.
num_replicas
def
__iter__
(
self
):
# deterministically shuffle based on epoch
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
epoch
)
indices
=
[]
for
i
,
size
in
enumerate
(
self
.
group_sizes
):
if
size
>
0
:
indice
=
np
.
where
(
self
.
flag
==
i
)[
0
]
assert
len
(
indice
)
==
size
# add .numpy() to avoid bug when selecting indice in parrots.
# TODO: check whether torch.randperm() can be replaced by
# numpy.random.permutation().
indice
=
indice
[
list
(
torch
.
randperm
(
int
(
size
),
generator
=
g
).
numpy
())].
tolist
()
extra
=
int
(
math
.
ceil
(
size
*
1.0
/
self
.
samples_per_gpu
/
self
.
num_replicas
)
)
*
self
.
samples_per_gpu
*
self
.
num_replicas
-
len
(
indice
)
# pad indice
tmp
=
indice
.
copy
()
for
_
in
range
(
extra
//
size
):
indice
.
extend
(
tmp
)
indice
.
extend
(
tmp
[:
extra
%
size
])
indices
.
extend
(
indice
)
assert
len
(
indices
)
==
self
.
total_size
indices
=
[
indices
[
j
]
for
i
in
list
(
torch
.
randperm
(
len
(
indices
)
//
self
.
samples_per_gpu
,
generator
=
g
))
for
j
in
range
(
i
*
self
.
samples_per_gpu
,
(
i
+
1
)
*
self
.
samples_per_gpu
)
]
# subsample
offset
=
self
.
num_samples
*
self
.
rank
indices
=
indices
[
offset
:
offset
+
self
.
num_samples
]
assert
len
(
indices
)
==
self
.
num_samples
return
iter
(
indices
)
def
__len__
(
self
):
return
self
.
num_samples
def
set_epoch
(
self
,
epoch
):
self
.
epoch
=
epoch
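A toy sketch of GroupSampler's grouping contract: the dataset only needs a `flag` array mapping each sample to an aspect-ratio group. The stub class and sizes below are hypothetical, not part of this repository.

import numpy as np


class ToyDataset:
    """Hypothetical dataset stub: 10 samples split into two groups."""

    def __init__(self, n=10):
        self.flag = np.array([i % 2 for i in range(n)], dtype=np.uint8)

    def __len__(self):
        return len(self.flag)


sampler = GroupSampler(ToyDataset(), samples_per_gpu=4)
indices = list(iter(sampler))
# Every consecutive chunk of samples_per_gpu indices comes from one group;
# each group is padded up to a multiple of samples_per_gpu (here 16 total).
print(len(sampler), indices)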
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/utils.py (new file, mode 100644)

import copy
import warnings


def replace_ImageToTensor(pipelines):
    """Replace the ImageToTensor transform in a data pipeline with
    DefaultFormatBundle, which is normally useful in batch inference.

    Args:
        pipelines (list[dict]): Data pipeline configs.

    Returns:
        list: The new pipeline list with all ImageToTensor replaced by
            DefaultFormatBundle.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """
    pipelines = copy.deepcopy(pipelines)
    for i, pipeline in enumerate(pipelines):
        if pipeline['type'] == 'MultiScaleFlipAug':
            assert 'transforms' in pipeline
            pipeline['transforms'] = replace_ImageToTensor(
                pipeline['transforms'])
        elif pipeline['type'] == 'ImageToTensor':
            warnings.warn(
                '"ImageToTensor" pipeline is replaced by '
                '"DefaultFormatBundle" for batch inference. It is '
                'recommended to manually replace it in the test '
                'data pipeline in your config file.', UserWarning)
            pipelines[i] = {'type': 'DefaultFormatBundle'}
    return pipelines


def get_loading_pipeline(pipeline):
    """Only keep the loading image and annotations related configuration.

    Args:
        pipeline (list[dict]): Data pipeline configs.

    Returns:
        list[dict]: The new pipeline list that only keeps the loading image
            and annotations related configuration.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', **img_norm_cfg),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines ==\
        ...        get_loading_pipeline(pipelines)
    """
    loading_pipeline_cfg = []
    for cfg in pipeline:
        if cfg['type'].startswith('Load'):
            loading_pipeline_cfg.append(cfg)
    assert len(loading_pipeline_cfg) == 2, \
        'The data pipeline in your config file must include ' \
        'loading image and annotations related pipeline.'
    return loading_pipeline_cfg
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/voc.py (new file, mode 100644)

from collections import OrderedDict

from mmcv.utils import print_log

from mmdet.core import eval_map, eval_recalls
from .builder import DATASETS
from .xml_style import XMLDataset


@DATASETS.register_module()
class VOCDataset(XMLDataset):

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        super(VOCDataset, self).__init__(**kwargs)
        if 'VOC2007' in self.img_prefix:
            self.year = 2007
        elif 'VOC2012' in self.img_prefix:
            self.year = 2012
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate in VOC protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'mAP', 'recall'.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. Default: 0.5.
            scale_ranges (list[tuple], optional): Scale ranges for evaluating
                mAP. If not specified, all bounding boxes would be included in
                evaluation. Default: None.

        Returns:
            dict[str, float]: AP/recall metrics.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            if self.year == 2007:
                ds_name = 'voc07'
            else:
                ds_name = self.CLASSES
            mean_aps = []
            for iou_thr in iou_thrs:
                print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=None,
                    iou_thr=iou_thr,
                    dataset=ds_name,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thr):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results
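A synthetic sanity check of the eval_map path used by evaluate above; this is a sketch assuming the mmdet.core.eval_map signature shown in this file's imports. The single image, one class, and box values are made up, and a perfect detection should yield a mean AP of 1.0.

import numpy as np
from mmdet.core import eval_map

# One image, one class: det_results[img][cls] is an (n, 5) array of
# x1, y1, x2, y2, score; the annotation matches the detection exactly.
det_results = [[np.array([[10., 10., 50., 50., 0.9]])]]
annotations = [
    dict(
        bboxes=np.array([[10., 10., 50., 50.]]),
        labels=np.array([0]),
        bboxes_ignore=np.zeros((0, 4)),
        labels_ignore=np.zeros((0, )))
]
mean_ap, _ = eval_map(det_results, annotations, iou_thr=0.5)
print(mean_ap)  # expected: 1.0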
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/wider_face.py (new file, mode 100644)

import os.path as osp
import xml.etree.ElementTree as ET

import mmcv

from .builder import DATASETS
from .xml_style import XMLDataset


@DATASETS.register_module()
class WIDERFaceDataset(XMLDataset):
    """Reader for the WIDER Face dataset in PASCAL VOC format.

    Conversion scripts can be found in
    https://github.com/sovrasov/wider-face-pascal-voc-annotations
    """
    CLASSES = ('face', )

    def __init__(self, **kwargs):
        super(WIDERFaceDataset, self).__init__(**kwargs)

    def load_annotations(self, ann_file):
        """Load annotation from WIDERFace XML style annotation file.

        Args:
            ann_file (str): Path of XML file.

        Returns:
            list[dict]: Annotation info from XML file.
        """

        data_infos = []
        img_ids = mmcv.list_from_file(ann_file)
        for img_id in img_ids:
            filename = f'{img_id}.jpg'
            xml_path = osp.join(self.img_prefix, 'Annotations',
                                f'{img_id}.xml')
            tree = ET.parse(xml_path)
            root = tree.getroot()
            size = root.find('size')
            width = int(size.find('width').text)
            height = int(size.find('height').text)
            folder = root.find('folder').text
            data_infos.append(
                dict(
                    id=img_id,
                    filename=osp.join(folder, filename),
                    width=width,
                    height=height))

        return data_infos
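For orientation, a runnable sketch of the minimal XML shape load_annotations expects; the field values below are made up, not taken from the dataset.

import xml.etree.ElementTree as ET

xml_snippet = '''
<annotation>
  <folder>0--Parade</folder>
  <size><width>1024</width><height>768</height></size>
</annotation>
'''
root = ET.fromstring(xml_snippet)
size = root.find('size')
# Same field accesses as in load_annotations above.
print(root.find('folder').text,
      int(size.find('width').text), int(size.find('height').text))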
PyTorch/NLP/Conformer-main/mmdetection/mmdet/datasets/xml_style.py (new file, mode 100644)

import os.path as osp
import xml.etree.ElementTree as ET

import mmcv
import numpy as np
from PIL import Image

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class XMLDataset(CustomDataset):
    """XML dataset for detection.

    Args:
        min_size (int | float, optional): The minimum size of bounding
            boxes in the images. If the size of a bounding box is less than
            ``min_size``, it would be added to the ignored field.
    """

    def __init__(self, min_size=None, **kwargs):
        super(XMLDataset, self).__init__(**kwargs)
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        self.min_size = min_size

    def load_annotations(self, ann_file):
        """Load annotation from XML style ann_file.

        Args:
            ann_file (str): Path of XML file.

        Returns:
            list[dict]: Annotation info from XML file.
        """

        data_infos = []
        img_ids = mmcv.list_from_file(ann_file)
        for img_id in img_ids:
            filename = f'JPEGImages/{img_id}.jpg'
            xml_path = osp.join(self.img_prefix, 'Annotations',
                                f'{img_id}.xml')
            tree = ET.parse(xml_path)
            root = tree.getroot()
            size = root.find('size')
            width = 0
            height = 0
            if size is not None:
                width = int(size.find('width').text)
                height = int(size.find('height').text)
            else:
                img_path = osp.join(self.img_prefix, 'JPEGImages',
                                    '{}.jpg'.format(img_id))
                img = Image.open(img_path)
                width, height = img.size
            data_infos.append(
                dict(id=img_id, filename=filename, width=width,
                     height=height))

        return data_infos

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without annotation."""
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) < min_size:
                continue
            if self.filter_empty_gt:
                img_id = img_info['id']
                xml_path = osp.join(self.img_prefix, 'Annotations',
                                    f'{img_id}.xml')
                tree = ET.parse(xml_path)
                root = tree.getroot()
                for obj in root.findall('object'):
                    name = obj.find('name').text
                    if name in self.CLASSES:
                        valid_inds.append(i)
                        break
            else:
                valid_inds.append(i)
        return valid_inds

    def get_ann_info(self, idx):
        """Get annotation from XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        img_id = self.data_infos[idx]['id']
        xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml')
        tree = ET.parse(xml_path)
        root = tree.getroot()
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            name = obj.find('name').text
            if name not in self.CLASSES:
                continue
            label = self.cat2label[name]
            difficult = int(obj.find('difficult').text)
            bnd_box = obj.find('bndbox')
            # TODO: check whether it is necessary to use int
            # Coordinates may be float type
            bbox = [
                int(float(bnd_box.find('xmin').text)),
                int(float(bnd_box.find('ymin').text)),
                int(float(bnd_box.find('xmax').text)),
                int(float(bnd_box.find('ymax').text))
            ]
            ignore = False
            if self.min_size:
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                if w < self.min_size or h < self.min_size:
                    ignore = True
            if difficult or ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        if not bboxes:
            bboxes = np.zeros((0, 4))
            labels = np.zeros((0, ))
        else:
            bboxes = np.array(bboxes, ndmin=2) - 1
            labels = np.array(labels)
        if not bboxes_ignore:
            bboxes_ignore = np.zeros((0, 4))
            labels_ignore = np.zeros((0, ))
        else:
            bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
            labels_ignore = np.array(labels_ignore)
        ann = dict(
            bboxes=bboxes.astype(np.float32),
            labels=labels.astype(np.int64),
            bboxes_ignore=bboxes_ignore.astype(np.float32),
            labels_ignore=labels_ignore.astype(np.int64))
        return ann

    def get_cat_ids(self, idx):
        """Get category ids in XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        cat_ids = []
        img_id = self.data_infos[idx]['id']
        xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml')
        tree = ET.parse(xml_path)
        root = tree.getroot()
        for obj in root.findall('object'):
            name = obj.find('name').text
            if name not in self.CLASSES:
                continue
            label = self.cat2label[name]
            cat_ids.append(label)

        return cat_ids
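For orientation, the dict returned by get_ann_info has the following shape. The numbers below are made-up examples, not real annotations: two kept boxes plus one 'difficult' box that was moved to the ignore fields.

import numpy as np

ann = dict(
    bboxes=np.array([[47., 239., 194., 370.],
                     [7., 11., 351., 497.]], dtype=np.float32),
    labels=np.array([11, 14], dtype=np.int64),
    bboxes_ignore=np.array([[138., 199., 206., 300.]], dtype=np.float32),
    labels_ignore=np.array([14], dtype=np.int64))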
PyTorch/NLP/Conformer-main/mmdetection/mmdet/models/__init__.py (new file, mode 100644)

from .backbones import *  # noqa: F401,F403
from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                      ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
                      build_detector, build_head, build_loss, build_neck,
                      build_roi_extractor, build_shared_head)
from .dense_heads import *  # noqa: F401,F403
from .detectors import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .roi_heads import *  # noqa: F401,F403

__all__ = [
    'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
    'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
    'build_shared_head', 'build_head', 'build_loss', 'build_detector'
]
PyTorch/NLP/Conformer-main/mmdetection/mmdet/models/backbones/Conformer.py (new file, mode 100644)

import math
import warnings
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import load_checkpoint

from mmdet.utils import get_root_logger
from ..builder import BACKBONES

_DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16)

import pdb


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official
    # releases - RW
    # Method based on
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main
    path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return self.drop_path(x, self.drop_prob, self.training)

    def drop_path(self, x, drop_prob: float = 0., training: bool = False):
        """Drop paths (Stochastic Depth) per sample (when applied in the main
        path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet,
        etc networks, however, the original name is misleading as
        'Drop Connect' is a different form of dropout in a separate paper...
        See discussion:
        https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
        ... I've opted for changing the layer and argument names to
        'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.
        """
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually
        # to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is
        # better than dropout here
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class ConvBlock(nn.Module):

    def __init__(self,
                 inplanes,
                 outplanes,
                 stride=1,
                 res_conv=False,
                 act_layer=nn.ReLU,
                 groups=1,
                 norm_layer=partial(nn.BatchNorm2d, eps=1e-6),
                 drop_block=None,
                 drop_path=None):
        super(ConvBlock, self).__init__()

        expansion = 4
        med_planes = outplanes // expansion

        self.conv1 = nn.Conv2d(
            inplanes, med_planes, kernel_size=1, stride=1, padding=0,
            bias=False)
        self.bn1 = norm_layer(med_planes)
        self.act1 = act_layer(inplace=True)

        self.conv2 = nn.Conv2d(
            med_planes,
            med_planes,
            kernel_size=3,
            stride=stride,
            groups=groups,
            padding=1,
            bias=False)
        self.bn2 = norm_layer(med_planes)
        self.act2 = act_layer(inplace=True)

        self.conv3 = nn.Conv2d(
            med_planes, outplanes, kernel_size=1, stride=1, padding=0,
            bias=False)
        self.bn3 = norm_layer(outplanes)
        self.act3 = act_layer(inplace=True)

        if res_conv:
            self.residual_conv = nn.Conv2d(
                inplanes, outplanes, kernel_size=1, stride=stride, padding=0,
                bias=False)
            self.residual_bn = norm_layer(outplanes)

        self.res_conv = res_conv
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x, x_t=None, return_x_2=True):
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)

        x = self.conv2(x) if x_t is None else self.conv2(x + x_t)
        x = self.bn2(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x2 = self.act2(x)

        x = self.conv3(x2)
        x = self.bn3(x)
        if self.drop_block is not None:
            x = self.drop_block(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.res_conv:
            residual = self.residual_conv(residual)
            residual = self.residual_bn(residual)

        x += residual
        x = self.act3(x)

        if return_x_2:
            return x, x2
        else:
            return x


class FCUDown(nn.Module):
    """CNN feature maps -> Transformer patch embeddings."""

    def __init__(self,
                 inplanes,
                 outplanes,
                 dw_stride,
                 act_layer=nn.GELU,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super(FCUDown, self).__init__()
        self.dw_stride = dw_stride

        self.conv_project = nn.Conv2d(
            inplanes, outplanes, kernel_size=1, stride=1, padding=0)
        self.sample_pooling = nn.AvgPool2d(
            kernel_size=dw_stride, stride=dw_stride)

        self.ln = norm_layer(outplanes)
        self.act = act_layer()

    def forward(self, x, x_t):
        x = self.conv_project(x)  # [N, C, H, W]

        x = self.sample_pooling(x).flatten(2).transpose(1, 2)
        x = self.ln(x)
        x = self.act(x)

        x = torch.cat([x_t[:, 0][:, None, :], x], dim=1)

        return x


class FCUUp(nn.Module):
    """Transformer patch embeddings -> CNN feature maps."""

    def __init__(self,
                 inplanes,
                 outplanes,
                 up_stride,
                 act_layer=nn.ReLU,
                 norm_layer=partial(nn.BatchNorm2d, eps=1e-6)):
        super(FCUUp, self).__init__()

        self.up_stride = up_stride
        self.conv_project = nn.Conv2d(
            inplanes, outplanes, kernel_size=1, stride=1, padding=0)
        self.bn = norm_layer(outplanes)
        self.act = act_layer()

    def forward(self, x, H, W):
        B, _, C = x.shape
        # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
        x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W)
        x_r = self.act(self.bn(self.conv_project(x_r)))

        return F.interpolate(
            x_r, size=(H * self.up_stride, W * self.up_stride))


class Med_ConvBlock(nn.Module):
    """Special case of ConvBlock used between down-sampling stages."""

    def __init__(self,
                 inplanes,
                 act_layer=nn.ReLU,
                 groups=1,
                 norm_layer=partial(nn.BatchNorm2d, eps=1e-6),
                 drop_block=None,
                 drop_path=None):
        super(Med_ConvBlock, self).__init__()

        expansion = 4
        med_planes = inplanes // expansion

        self.conv1 = nn.Conv2d(
            inplanes, med_planes, kernel_size=1, stride=1, padding=0,
            bias=False)
        self.bn1 = norm_layer(med_planes)
        self.act1 = act_layer(inplace=True)

        self.conv2 = nn.Conv2d(
            med_planes,
            med_planes,
            kernel_size=3,
            stride=1,
            groups=groups,
            padding=1,
            bias=False)
        self.bn2 = norm_layer(med_planes)
        self.act2 = act_layer(inplace=True)

        self.conv3 = nn.Conv2d(
            med_planes, inplanes, kernel_size=1, stride=1, padding=0,
            bias=False)
        self.bn3 = norm_layer(inplanes)
        self.act3 = act_layer(inplace=True)

        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        if self.drop_block is not None:
            x = self.drop_block(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x += residual
        x = self.act3(x)

        return x


class ConvTransBlock(nn.Module):
    """Basic module for Conformer: keeps feature maps for the CNN block and
    patch embeddings for the transformer encoder block."""

    def __init__(self,
                 inplanes,
                 outplanes,
                 res_conv,
                 stride,
                 dw_stride,
                 embed_dim,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 last_fusion=False,
                 num_med_block=0,
                 groups=1):
        super(ConvTransBlock, self).__init__()
        expansion = 4
        self.cnn_block = ConvBlock(
            inplanes=inplanes,
            outplanes=outplanes,
            res_conv=res_conv,
            stride=stride,
            groups=groups)

        if last_fusion:
            self.fusion_block = ConvBlock(
                inplanes=outplanes,
                outplanes=outplanes,
                stride=2,
                res_conv=True,
                groups=groups)
        else:
            self.fusion_block = ConvBlock(
                inplanes=outplanes, outplanes=outplanes, groups=groups)

        if num_med_block > 0:
            self.med_block = []
            for i in range(num_med_block):
                self.med_block.append(
                    Med_ConvBlock(inplanes=outplanes, groups=groups))
            self.med_block = nn.ModuleList(self.med_block)

        self.squeeze_block = FCUDown(
            inplanes=outplanes // expansion,
            outplanes=embed_dim,
            dw_stride=dw_stride)

        self.expand_block = FCUUp(
            inplanes=embed_dim,
            outplanes=outplanes // expansion,
            up_stride=dw_stride)

        self.trans_block = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=drop_path_rate)

        self.dw_stride = dw_stride
        self.embed_dim = embed_dim
        self.num_med_block = num_med_block
        self.last_fusion = last_fusion

    def forward(self, x, x_t):
        x, x2 = self.cnn_block(x)

        _, _, H, W = x2.shape

        x_st = self.squeeze_block(x2, x_t)

        x_t = self.trans_block(x_st + x_t)

        if self.num_med_block > 0:
            for m in self.med_block:
                x = m(x)

        x_t_r = self.expand_block(x_t, H // self.dw_stride,
                                  W // self.dw_stride)
        x = self.fusion_block(x, x_t_r, return_x_2=False)

        return x, x_t


@BACKBONES.register_module()
class Conformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input
    stage."""

    def __init__(self,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 base_channel=64,
                 channel_ratio=4,
                 num_med_block=0,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_eval=True,
                 frozen_stages=1,
                 return_cls_token=False):

        # Transformer
        super().__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.return_cls_token = return_cls_token
        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # stochastic depth decay rule
        self.trans_dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]

        # Classifiers
        if self.return_cls_token:
            self.trans_norm = nn.LayerNorm(embed_dim)
            self.trans_cls_head = nn.Linear(
                embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # self.pooling = nn.AdaptiveAvgPool2d(1)
        # self.conv_cls_head = nn.Linear(1024, num_classes)

        # Stem stage: get the feature maps by conv block (copied from
        # resnet.py)
        self.conv1 = nn.Conv2d(
            in_chans, 64, kernel_size=7, stride=2, padding=3,
            bias=False)  # 1 / 2 [112, 112]
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1)  # 1 / 4 [56, 56]

        # 1 stage
        stage_1_channel = int(base_channel * channel_ratio)
        trans_dw_stride = patch_size // 4
        self.conv_1 = ConvBlock(
            inplanes=64, outplanes=stage_1_channel, res_conv=True, stride=1)
        self.trans_patch_conv = nn.Conv2d(
            64,
            embed_dim,
            kernel_size=trans_dw_stride,
            stride=trans_dw_stride,
            padding=0)
        self.trans_1 = Block(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=self.trans_dpr[0],
        )

        # 2~4 stage
        init_stage = 2
        fin_stage = depth // 3 + 1
        for i in range(init_stage, fin_stage):
            self.add_module(
                'conv_trans_' + str(i),
                ConvTransBlock(
                    stage_1_channel,
                    stage_1_channel,
                    False,
                    1,
                    dw_stride=trans_dw_stride,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=self.trans_dpr[i - 1],
                    num_med_block=num_med_block))

        stage_2_channel = int(base_channel * channel_ratio * 2)
        # 5~8 stage
        init_stage = fin_stage  # 5
        fin_stage = fin_stage + depth // 3  # 9
        for i in range(init_stage, fin_stage):
            s = 2 if i == init_stage else 1
            in_channel = stage_1_channel if i == init_stage \
                else stage_2_channel
            res_conv = True if i == init_stage else False
            self.add_module(
                'conv_trans_' + str(i),
                ConvTransBlock(
                    in_channel,
                    stage_2_channel,
                    res_conv,
                    s,
                    dw_stride=trans_dw_stride // 2,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=self.trans_dpr[i - 1],
                    num_med_block=num_med_block))

        stage_3_channel = int(base_channel * channel_ratio * 2 * 2)
        # 9~12 stage
        init_stage = fin_stage  # 9
        fin_stage = fin_stage + depth // 3  # 13
        for i in range(init_stage, fin_stage):
            s = 2 if i == init_stage else 1
            in_channel = stage_2_channel if i == init_stage \
                else stage_3_channel
            res_conv = True if i == init_stage else False
            last_fusion = True if i == depth else False
            self.add_module(
                'conv_trans_' + str(i),
                ConvTransBlock(
                    in_channel,
                    stage_3_channel,
                    res_conv,
                    s,
                    dw_stride=trans_dw_stride // 4,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=self.trans_dpr[i - 1],
                    num_med_block=num_med_block,
                    last_fusion=last_fusion))
        self.fin_stage = fin_stage

        trunc_normal_(self.cls_token, std=.02)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(
                self,
                pretrained,
                strict=False,
                logger=logger,
                map_location='cpu')
        elif pretrained is None:
            self.apply(self._init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token'}

    def forward(self, x):
        output = []
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        # stem
        x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))

        # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56]
        x = self.conv_1(x_base, return_x_2=False)

        x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
        x_t = torch.cat([cls_tokens, x_t], dim=1)
        x_t = self.trans_1(x_t)

        # 2 ~ final
        for i in range(2, self.fin_stage):
            x, x_t = eval('self.conv_trans_' + str(i))(x, x_t)
            if i in [4, 8, 11, 12]:
                output.append(x)

        if self.return_cls_token:
            return tuple(output), self.trans_cls_head(
                self.trans_norm(x_t[:, [0, ]]))
        else:
            return tuple(output)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.bn1.eval()
            for m in [self.conv1, self.bn1]:
                for param in m.parameters():
                    param.requires_grad = False
        # for i in range(1, self.frozen_stages + 1):
        #     m = getattr(self, f'layer{i}')
        #     m.eval()
        #     for param in m.parameters():
        #         param.requires_grad = False

    def freeze_bn(self, m):
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super(Conformer, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            self.apply(self.freeze_bn)
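A minimal smoke test of the backbone above (a sketch: the hyper-parameters mirror the __init__ defaults and the 224x224 input size is arbitrary). With return_cls_token=False, forward returns one feature map per stage in [4, 8, 11, 12], exercising the FCUDown/FCUUp coupling along the way.

import torch

model = Conformer(patch_size=16, embed_dim=768, depth=12, num_heads=12)
model.init_weights(pretrained=None)
model.eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 224, 224))
for f in feats:
    print(tuple(f.shape))  # channels 256/512/1024/1024 at decreasing scales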
PyTorch/NLP/Conformer-main/mmdetection/mmdet/models/backbones/__init__.py (new file, mode 100644)

from .darknet import Darknet
from .detectors_resnet import DetectoRS_ResNet
from .detectors_resnext import DetectoRS_ResNeXt
from .hourglass import HourglassNet
from .hrnet import HRNet
from .regnet import RegNet
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .trident_resnet import TridentResNet
from .Conformer import Conformer

__all__ = [
    'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
    'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
    'ResNeSt', 'TridentResNet', 'Conformer'
]
PyTorch/NLP/Conformer-main/mmdetection/mmdet/models/backbones/darknet.py (new file, mode 100644)

# Copyright (c) 2019 Western Digital Corporation or its affiliates.

import logging

import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES


class ResBlock(nn.Module):
    """The basic residual block used in Darknet. Each ResBlock consists of two
    ConvModules and the input is added to the final output. Each ConvModule is
    composed of Conv, BN, and LeakyReLU. In the YoloV3 paper, the first
    convLayer has half as many filters as the second convLayer. The first
    convLayer has a filter size of 1x1 and the second one has a filter size
    of 3x3.

    Args:
        in_channels (int): The input channels. Must be even.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', negative_slope=0.1).
    """

    def __init__(self,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
        super(ResBlock, self).__init__()
        assert in_channels % 2 == 0  # ensure the in_channels is even
        half_in_channels = in_channels // 2

        # shortcut
        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.conv1 = ConvModule(in_channels, half_in_channels, 1, **cfg)
        self.conv2 = ConvModule(
            half_in_channels, in_channels, 3, padding=1, **cfg)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = out + residual

        return out


@BACKBONES.register_module()
class Darknet(nn.Module):
    """Darknet backbone.

    Args:
        depth (int): Depth of Darknet. Currently only support 53.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', negative_slope=0.1).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.

    Example:
        >>> from mmdet.models import Darknet
        >>> import torch
        >>> self = Darknet(depth=53)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """

    # Dict(depth: (layers, channels))
    arch_settings = {
        53: ((1, 2, 8, 8, 4), ((32, 64), (64, 128), (128, 256), (256, 512),
                               (512, 1024)))
    }

    def __init__(self,
                 depth=53,
                 out_indices=(3, 4, 5),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
                 norm_eval=True):
        super(Darknet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for darknet')
        self.depth = depth
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.layers, self.channels = self.arch_settings[depth]

        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.conv1 = ConvModule(3, 32, 3, padding=1, **cfg)

        self.cr_blocks = ['conv1']
        for i, n_layers in enumerate(self.layers):
            layer_name = f'conv_res_block{i + 1}'
            in_c, out_c = self.channels[i]
            self.add_module(
                layer_name,
                self.make_conv_res_block(in_c, out_c, n_layers, **cfg))
            self.cr_blocks.append(layer_name)

        self.norm_eval = norm_eval

    def forward(self, x):
        outs = []
        for i, layer_name in enumerate(self.cr_blocks):
            cr_block = getattr(self, layer_name)
            x = cr_block(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages):
                m = getattr(self, self.cr_blocks[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(Darknet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    @staticmethod
    def make_conv_res_block(in_channels,
                            out_channels,
                            res_repeat,
                            conv_cfg=None,
                            norm_cfg=dict(type='BN', requires_grad=True),
                            act_cfg=dict(
                                type='LeakyReLU', negative_slope=0.1)):
        """In the Darknet backbone, a ConvLayer is usually followed by
        ResBlocks; this function builds that stack. The Conv layer always has
        3x3 filters with stride=2. The number of filters in the Conv layer is
        the same as the out channels of the ResBlock.

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            res_repeat (int): The number of ResBlocks.
            conv_cfg (dict): Config dict for convolution layer. Default: None.
            norm_cfg (dict): Dictionary to construct and config norm layer.
                Default: dict(type='BN', requires_grad=True)
            act_cfg (dict): Config dict for activation layer.
                Default: dict(type='LeakyReLU', negative_slope=0.1).
        """
        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        model = nn.Sequential()
        model.add_module(
            'conv',
            ConvModule(
                in_channels, out_channels, 3, stride=2, padding=1, **cfg))
        for idx in range(res_repeat):
            model.add_module('res{}'.format(idx),
                             ResBlock(out_channels, **cfg))
        return model