Unverified Commit 83e5a106 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Add BEiT (#12994)



* First pass

* Make conversion script work

* Improve conversion script

* Fix bug, conversion script working

* Improve conversion script, implement BEiTFeatureExtractor

* Make conversion script work based on URL

* Improve conversion script

* Add tests, add documentation

* Fix bug in conversion script

* Fix another bug

* Add support for converting masked image modeling model

* Add support for converting masked image modeling

* Fix bug

* Add print statement for debugging

* Fix another bug

* Make conversion script finally work for masked image modeling models

* Move id2label for datasets to JSON files on the hub

* Make sure id's are read in as integers

* Add integration tests

* Make style & quality

* Fix test, add BEiT to README

* Apply suggestions from @sgugger's review

* Apply suggestions from code review

* Make quality

* Replace nielsr by microsoft in tests, add docs

* Rename BEiT to Beit

* Minor fix

* Fix docs of BeitForMaskedImageModeling
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 0dd1152c
...@@ -7,6 +7,11 @@ class ImageFeatureExtractionMixin: ...@@ -7,6 +7,11 @@ class ImageFeatureExtractionMixin:
requires_backends(self, ["vision"]) requires_backends(self, ["vision"])
class BeitFeatureExtractor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class CLIPFeatureExtractor: class CLIPFeatureExtractor:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"]) requires_backends(self, ["vision"])
......
# ImageNet 2012 id's to class names
id2label = {
0: "tench, Tinca tinca",
1: "goldfish, Carassius auratus",
2: "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
3: "tiger shark, Galeocerdo cuvieri",
4: "hammerhead, hammerhead shark",
5: "electric ray, crampfish, numbfish, torpedo",
6: "stingray",
7: "cock",
8: "hen",
9: "ostrich, Struthio camelus",
10: "brambling, Fringilla montifringilla",
11: "goldfinch, Carduelis carduelis",
12: "house finch, linnet, Carpodacus mexicanus",
13: "junco, snowbird",
14: "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
15: "robin, American robin, Turdus migratorius",
16: "bulbul",
17: "jay",
18: "magpie",
19: "chickadee",
20: "water ouzel, dipper",
21: "kite",
22: "bald eagle, American eagle, Haliaeetus leucocephalus",
23: "vulture",
24: "great grey owl, great gray owl, Strix nebulosa",
25: "European fire salamander, Salamandra salamandra",
26: "common newt, Triturus vulgaris",
27: "eft",
28: "spotted salamander, Ambystoma maculatum",
29: "axolotl, mud puppy, Ambystoma mexicanum",
30: "bullfrog, Rana catesbeiana",
31: "tree frog, tree-frog",
32: "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
33: "loggerhead, loggerhead turtle, Caretta caretta",
34: "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
35: "mud turtle",
36: "terrapin",
37: "box turtle, box tortoise",
38: "banded gecko",
39: "common iguana, iguana, Iguana iguana",
40: "American chameleon, anole, Anolis carolinensis",
41: "whiptail, whiptail lizard",
42: "agama",
43: "frilled lizard, Chlamydosaurus kingi",
44: "alligator lizard",
45: "Gila monster, Heloderma suspectum",
46: "green lizard, Lacerta viridis",
47: "African chameleon, Chamaeleo chamaeleon",
48: "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
49: "African crocodile, Nile crocodile, Crocodylus niloticus",
50: "American alligator, Alligator mississipiensis",
51: "triceratops",
52: "thunder snake, worm snake, Carphophis amoenus",
53: "ringneck snake, ring-necked snake, ring snake",
54: "hognose snake, puff adder, sand viper",
55: "green snake, grass snake",
56: "king snake, kingsnake",
57: "garter snake, grass snake",
58: "water snake",
59: "vine snake",
60: "night snake, Hypsiglena torquata",
61: "boa constrictor, Constrictor constrictor",
62: "rock python, rock snake, Python sebae",
63: "Indian cobra, Naja naja",
64: "green mamba",
65: "sea snake",
66: "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
67: "diamondback, diamondback rattlesnake, Crotalus adamanteus",
68: "sidewinder, horned rattlesnake, Crotalus cerastes",
69: "trilobite",
70: "harvestman, daddy longlegs, Phalangium opilio",
71: "scorpion",
72: "black and gold garden spider, Argiope aurantia",
73: "barn spider, Araneus cavaticus",
74: "garden spider, Aranea diademata",
75: "black widow, Latrodectus mactans",
76: "tarantula",
77: "wolf spider, hunting spider",
78: "tick",
79: "centipede",
80: "black grouse",
81: "ptarmigan",
82: "ruffed grouse, partridge, Bonasa umbellus",
83: "prairie chicken, prairie grouse, prairie fowl",
84: "peacock",
85: "quail",
86: "partridge",
87: "African grey, African gray, Psittacus erithacus",
88: "macaw",
89: "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
90: "lorikeet",
91: "coucal",
92: "bee eater",
93: "hornbill",
94: "hummingbird",
95: "jacamar",
96: "toucan",
97: "drake",
98: "red-breasted merganser, Mergus serrator",
99: "goose",
100: "black swan, Cygnus atratus",
101: "tusker",
102: "echidna, spiny anteater, anteater",
103: "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
104: "wallaby, brush kangaroo",
105: "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
106: "wombat",
107: "jellyfish",
108: "sea anemone, anemone",
109: "brain coral",
110: "flatworm, platyhelminth",
111: "nematode, nematode worm, roundworm",
112: "conch",
113: "snail",
114: "slug",
115: "sea slug, nudibranch",
116: "chiton, coat-of-mail shell, sea cradle, polyplacophore",
117: "chambered nautilus, pearly nautilus, nautilus",
118: "Dungeness crab, Cancer magister",
119: "rock crab, Cancer irroratus",
120: "fiddler crab",
121: "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
122: "American lobster, Northern lobster, Maine lobster, Homarus americanus",
123: "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
124: "crayfish, crawfish, crawdad, crawdaddy",
125: "hermit crab",
126: "isopod",
127: "white stork, Ciconia ciconia",
128: "black stork, Ciconia nigra",
129: "spoonbill",
130: "flamingo",
131: "little blue heron, Egretta caerulea",
132: "American egret, great white heron, Egretta albus",
133: "bittern",
134: "crane",
135: "limpkin, Aramus pictus",
136: "European gallinule, Porphyrio porphyrio",
137: "American coot, marsh hen, mud hen, water hen, Fulica americana",
138: "bustard",
139: "ruddy turnstone, Arenaria interpres",
140: "red-backed sandpiper, dunlin, Erolia alpina",
141: "redshank, Tringa totanus",
142: "dowitcher",
143: "oystercatcher, oyster catcher",
144: "pelican",
145: "king penguin, Aptenodytes patagonica",
146: "albatross, mollymawk",
147: "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
148: "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
149: "dugong, Dugong dugon",
150: "sea lion",
151: "Chihuahua",
152: "Japanese spaniel",
153: "Maltese dog, Maltese terrier, Maltese",
154: "Pekinese, Pekingese, Peke",
155: "Shih-Tzu",
156: "Blenheim spaniel",
157: "papillon",
158: "toy terrier",
159: "Rhodesian ridgeback",
160: "Afghan hound, Afghan",
161: "basset, basset hound",
162: "beagle",
163: "bloodhound, sleuthhound",
164: "bluetick",
165: "black-and-tan coonhound",
166: "Walker hound, Walker foxhound",
167: "English foxhound",
168: "redbone",
169: "borzoi, Russian wolfhound",
170: "Irish wolfhound",
171: "Italian greyhound",
172: "whippet",
173: "Ibizan hound, Ibizan Podenco",
174: "Norwegian elkhound, elkhound",
175: "otterhound, otter hound",
176: "Saluki, gazelle hound",
177: "Scottish deerhound, deerhound",
178: "Weimaraner",
179: "Staffordshire bullterrier, Staffordshire bull terrier",
180: "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
181: "Bedlington terrier",
182: "Border terrier",
183: "Kerry blue terrier",
184: "Irish terrier",
185: "Norfolk terrier",
186: "Norwich terrier",
187: "Yorkshire terrier",
188: "wire-haired fox terrier",
189: "Lakeland terrier",
190: "Sealyham terrier, Sealyham",
191: "Airedale, Airedale terrier",
192: "cairn, cairn terrier",
193: "Australian terrier",
194: "Dandie Dinmont, Dandie Dinmont terrier",
195: "Boston bull, Boston terrier",
196: "miniature schnauzer",
197: "giant schnauzer",
198: "standard schnauzer",
199: "Scotch terrier, Scottish terrier, Scottie",
200: "Tibetan terrier, chrysanthemum dog",
201: "silky terrier, Sydney silky",
202: "soft-coated wheaten terrier",
203: "West Highland white terrier",
204: "Lhasa, Lhasa apso",
205: "flat-coated retriever",
206: "curly-coated retriever",
207: "golden retriever",
208: "Labrador retriever",
209: "Chesapeake Bay retriever",
210: "German short-haired pointer",
211: "vizsla, Hungarian pointer",
212: "English setter",
213: "Irish setter, red setter",
214: "Gordon setter",
215: "Brittany spaniel",
216: "clumber, clumber spaniel",
217: "English springer, English springer spaniel",
218: "Welsh springer spaniel",
219: "cocker spaniel, English cocker spaniel, cocker",
220: "Sussex spaniel",
221: "Irish water spaniel",
222: "kuvasz",
223: "schipperke",
224: "groenendael",
225: "malinois",
226: "briard",
227: "kelpie",
228: "komondor",
229: "Old English sheepdog, bobtail",
230: "Shetland sheepdog, Shetland sheep dog, Shetland",
231: "collie",
232: "Border collie",
233: "Bouvier des Flandres, Bouviers des Flandres",
234: "Rottweiler",
235: "German shepherd, German shepherd dog, German police dog, alsatian",
236: "Doberman, Doberman pinscher",
237: "miniature pinscher",
238: "Greater Swiss Mountain dog",
239: "Bernese mountain dog",
240: "Appenzeller",
241: "EntleBucher",
242: "boxer",
243: "bull mastiff",
244: "Tibetan mastiff",
245: "French bulldog",
246: "Great Dane",
247: "Saint Bernard, St Bernard",
248: "Eskimo dog, husky",
249: "malamute, malemute, Alaskan malamute",
250: "Siberian husky",
251: "dalmatian, coach dog, carriage dog",
252: "affenpinscher, monkey pinscher, monkey dog",
253: "basenji",
254: "pug, pug-dog",
255: "Leonberg",
256: "Newfoundland, Newfoundland dog",
257: "Great Pyrenees",
258: "Samoyed, Samoyede",
259: "Pomeranian",
260: "chow, chow chow",
261: "keeshond",
262: "Brabancon griffon",
263: "Pembroke, Pembroke Welsh corgi",
264: "Cardigan, Cardigan Welsh corgi",
265: "toy poodle",
266: "miniature poodle",
267: "standard poodle",
268: "Mexican hairless",
269: "timber wolf, grey wolf, gray wolf, Canis lupus",
270: "white wolf, Arctic wolf, Canis lupus tundrarum",
271: "red wolf, maned wolf, Canis rufus, Canis niger",
272: "coyote, prairie wolf, brush wolf, Canis latrans",
273: "dingo, warrigal, warragal, Canis dingo",
274: "dhole, Cuon alpinus",
275: "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
276: "hyena, hyaena",
277: "red fox, Vulpes vulpes",
278: "kit fox, Vulpes macrotis",
279: "Arctic fox, white fox, Alopex lagopus",
280: "grey fox, gray fox, Urocyon cinereoargenteus",
281: "tabby, tabby cat",
282: "tiger cat",
283: "Persian cat",
284: "Siamese cat, Siamese",
285: "Egyptian cat",
286: "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
287: "lynx, catamount",
288: "leopard, Panthera pardus",
289: "snow leopard, ounce, Panthera uncia",
290: "jaguar, panther, Panthera onca, Felis onca",
291: "lion, king of beasts, Panthera leo",
292: "tiger, Panthera tigris",
293: "cheetah, chetah, Acinonyx jubatus",
294: "brown bear, bruin, Ursus arctos",
295: "American black bear, black bear, Ursus americanus, Euarctos americanus",
296: "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
297: "sloth bear, Melursus ursinus, Ursus ursinus",
298: "mongoose",
299: "meerkat, mierkat",
300: "tiger beetle",
301: "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
302: "ground beetle, carabid beetle",
303: "long-horned beetle, longicorn, longicorn beetle",
304: "leaf beetle, chrysomelid",
305: "dung beetle",
306: "rhinoceros beetle",
307: "weevil",
308: "fly",
309: "bee",
310: "ant, emmet, pismire",
311: "grasshopper, hopper",
312: "cricket",
313: "walking stick, walkingstick, stick insect",
314: "cockroach, roach",
315: "mantis, mantid",
316: "cicada, cicala",
317: "leafhopper",
318: "lacewing, lacewing fly",
319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
320: "damselfly",
321: "admiral",
322: "ringlet, ringlet butterfly",
323: "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
324: "cabbage butterfly",
325: "sulphur butterfly, sulfur butterfly",
326: "lycaenid, lycaenid butterfly",
327: "starfish, sea star",
328: "sea urchin",
329: "sea cucumber, holothurian",
330: "wood rabbit, cottontail, cottontail rabbit",
331: "hare",
332: "Angora, Angora rabbit",
333: "hamster",
334: "porcupine, hedgehog",
335: "fox squirrel, eastern fox squirrel, Sciurus niger",
336: "marmot",
337: "beaver",
338: "guinea pig, Cavia cobaya",
339: "sorrel",
340: "zebra",
341: "hog, pig, grunter, squealer, Sus scrofa",
342: "wild boar, boar, Sus scrofa",
343: "warthog",
344: "hippopotamus, hippo, river horse, Hippopotamus amphibius",
345: "ox",
346: "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
347: "bison",
348: "ram, tup",
349: "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
350: "ibex, Capra ibex",
351: "hartebeest",
352: "impala, Aepyceros melampus",
353: "gazelle",
354: "Arabian camel, dromedary, Camelus dromedarius",
355: "llama",
356: "weasel",
357: "mink",
358: "polecat, fitch, foulmart, foumart, Mustela putorius",
359: "black-footed ferret, ferret, Mustela nigripes",
360: "otter",
361: "skunk, polecat, wood pussy",
362: "badger",
363: "armadillo",
364: "three-toed sloth, ai, Bradypus tridactylus",
365: "orangutan, orang, orangutang, Pongo pygmaeus",
366: "gorilla, Gorilla gorilla",
367: "chimpanzee, chimp, Pan troglodytes",
368: "gibbon, Hylobates lar",
369: "siamang, Hylobates syndactylus, Symphalangus syndactylus",
370: "guenon, guenon monkey",
371: "patas, hussar monkey, Erythrocebus patas",
372: "baboon",
373: "macaque",
374: "langur",
375: "colobus, colobus monkey",
376: "proboscis monkey, Nasalis larvatus",
377: "marmoset",
378: "capuchin, ringtail, Cebus capucinus",
379: "howler monkey, howler",
380: "titi, titi monkey",
381: "spider monkey, Ateles geoffroyi",
382: "squirrel monkey, Saimiri sciureus",
383: "Madagascar cat, ring-tailed lemur, Lemur catta",
384: "indri, indris, Indri indri, Indri brevicaudatus",
385: "Indian elephant, Elephas maximus",
386: "African elephant, Loxodonta africana",
387: "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
388: "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
389: "barracouta, snoek",
390: "eel",
391: "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
392: "rock beauty, Holocanthus tricolor",
393: "anemone fish",
394: "sturgeon",
395: "gar, garfish, garpike, billfish, Lepisosteus osseus",
396: "lionfish",
397: "puffer, pufferfish, blowfish, globefish",
398: "abacus",
399: "abaya",
400: "academic gown, academic robe, judge's robe",
401: "accordion, piano accordion, squeeze box",
402: "acoustic guitar",
403: "aircraft carrier, carrier, flattop, attack aircraft carrier",
404: "airliner",
405: "airship, dirigible",
406: "altar",
407: "ambulance",
408: "amphibian, amphibious vehicle",
409: "analog clock",
410: "apiary, bee house",
411: "apron",
412: "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
413: "assault rifle, assault gun",
414: "backpack, back pack, knapsack, packsack, rucksack, haversack",
415: "bakery, bakeshop, bakehouse",
416: "balance beam, beam",
417: "balloon",
418: "ballpoint, ballpoint pen, ballpen, Biro",
419: "Band Aid",
420: "banjo",
421: "bannister, banister, balustrade, balusters, handrail",
422: "barbell",
423: "barber chair",
424: "barbershop",
425: "barn",
426: "barometer",
427: "barrel, cask",
428: "barrow, garden cart, lawn cart, wheelbarrow",
429: "baseball",
430: "basketball",
431: "bassinet",
432: "bassoon",
433: "bathing cap, swimming cap",
434: "bath towel",
435: "bathtub, bathing tub, bath, tub",
436: "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
437: "beacon, lighthouse, beacon light, pharos",
438: "beaker",
439: "bearskin, busby, shako",
440: "beer bottle",
441: "beer glass",
442: "bell cote, bell cot",
443: "bib",
444: "bicycle-built-for-two, tandem bicycle, tandem",
445: "bikini, two-piece",
446: "binder, ring-binder",
447: "binoculars, field glasses, opera glasses",
448: "birdhouse",
449: "boathouse",
450: "bobsled, bobsleigh, bob",
451: "bolo tie, bolo, bola tie, bola",
452: "bonnet, poke bonnet",
453: "bookcase",
454: "bookshop, bookstore, bookstall",
455: "bottlecap",
456: "bow",
457: "bow tie, bow-tie, bowtie",
458: "brass, memorial tablet, plaque",
459: "brassiere, bra, bandeau",
460: "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
461: "breastplate, aegis, egis",
462: "broom",
463: "bucket, pail",
464: "buckle",
465: "bulletproof vest",
466: "bullet train, bullet",
467: "butcher shop, meat market",
468: "cab, hack, taxi, taxicab",
469: "caldron, cauldron",
470: "candle, taper, wax light",
471: "cannon",
472: "canoe",
473: "can opener, tin opener",
474: "cardigan",
475: "car mirror",
476: "carousel, carrousel, merry-go-round, roundabout, whirligig",
477: "carpenter's kit, tool kit",
478: "carton",
479: "car wheel",
480: "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
481: "cassette",
482: "cassette player",
483: "castle",
484: "catamaran",
485: "CD player",
486: "cello, violoncello",
487: "cellular telephone, cellular phone, cellphone, cell, mobile phone",
488: "chain",
489: "chainlink fence",
490: "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
491: "chain saw, chainsaw",
492: "chest",
493: "chiffonier, commode",
494: "chime, bell, gong",
495: "china cabinet, china closet",
496: "Christmas stocking",
497: "church, church building",
498: "cinema, movie theater, movie theatre, movie house, picture palace",
499: "cleaver, meat cleaver, chopper",
500: "cliff dwelling",
501: "cloak",
502: "clog, geta, patten, sabot",
503: "cocktail shaker",
504: "coffee mug",
505: "coffeepot",
506: "coil, spiral, volute, whorl, helix",
507: "combination lock",
508: "computer keyboard, keypad",
509: "confectionery, confectionary, candy store",
510: "container ship, containership, container vessel",
511: "convertible",
512: "corkscrew, bottle screw",
513: "cornet, horn, trumpet, trump",
514: "cowboy boot",
515: "cowboy hat, ten-gallon hat",
516: "cradle",
517: "crane",
518: "crash helmet",
519: "crate",
520: "crib, cot",
521: "Crock Pot",
522: "croquet ball",
523: "crutch",
524: "cuirass",
525: "dam, dike, dyke",
526: "desk",
527: "desktop computer",
528: "dial telephone, dial phone",
529: "diaper, nappy, napkin",
530: "digital clock",
531: "digital watch",
532: "dining table, board",
533: "dishrag, dishcloth",
534: "dishwasher, dish washer, dishwashing machine",
535: "disk brake, disc brake",
536: "dock, dockage, docking facility",
537: "dogsled, dog sled, dog sleigh",
538: "dome",
539: "doormat, welcome mat",
540: "drilling platform, offshore rig",
541: "drum, membranophone, tympan",
542: "drumstick",
543: "dumbbell",
544: "Dutch oven",
545: "electric fan, blower",
546: "electric guitar",
547: "electric locomotive",
548: "entertainment center",
549: "envelope",
550: "espresso maker",
551: "face powder",
552: "feather boa, boa",
553: "file, file cabinet, filing cabinet",
554: "fireboat",
555: "fire engine, fire truck",
556: "fire screen, fireguard",
557: "flagpole, flagstaff",
558: "flute, transverse flute",
559: "folding chair",
560: "football helmet",
561: "forklift",
562: "fountain",
563: "fountain pen",
564: "four-poster",
565: "freight car",
566: "French horn, horn",
567: "frying pan, frypan, skillet",
568: "fur coat",
569: "garbage truck, dustcart",
570: "gasmask, respirator, gas helmet",
571: "gas pump, gasoline pump, petrol pump, island dispenser",
572: "goblet",
573: "go-kart",
574: "golf ball",
575: "golfcart, golf cart",
576: "gondola",
577: "gong, tam-tam",
578: "gown",
579: "grand piano, grand",
580: "greenhouse, nursery, glasshouse",
581: "grille, radiator grille",
582: "grocery store, grocery, food market, market",
583: "guillotine",
584: "hair slide",
585: "hair spray",
586: "half track",
587: "hammer",
588: "hamper",
589: "hand blower, blow dryer, blow drier, hair dryer, hair drier",
590: "hand-held computer, hand-held microcomputer",
591: "handkerchief, hankie, hanky, hankey",
592: "hard disc, hard disk, fixed disk",
593: "harmonica, mouth organ, harp, mouth harp",
594: "harp",
595: "harvester, reaper",
596: "hatchet",
597: "holster",
598: "home theater, home theatre",
599: "honeycomb",
600: "hook, claw",
601: "hoopskirt, crinoline",
602: "horizontal bar, high bar",
603: "horse cart, horse-cart",
604: "hourglass",
605: "iPod",
606: "iron, smoothing iron",
607: "jack-o'-lantern",
608: "jean, blue jean, denim",
609: "jeep, landrover",
610: "jersey, T-shirt, tee shirt",
611: "jigsaw puzzle",
612: "jinrikisha, ricksha, rickshaw",
613: "joystick",
614: "kimono",
615: "knee pad",
616: "knot",
617: "lab coat, laboratory coat",
618: "ladle",
619: "lampshade, lamp shade",
620: "laptop, laptop computer",
621: "lawn mower, mower",
622: "lens cap, lens cover",
623: "letter opener, paper knife, paperknife",
624: "library",
625: "lifeboat",
626: "lighter, light, igniter, ignitor",
627: "limousine, limo",
628: "liner, ocean liner",
629: "lipstick, lip rouge",
630: "Loafer",
631: "lotion",
632: "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
633: "loupe, jeweler's loupe",
634: "lumbermill, sawmill",
635: "magnetic compass",
636: "mailbag, postbag",
637: "mailbox, letter box",
638: "maillot",
639: "maillot, tank suit",
640: "manhole cover",
641: "maraca",
642: "marimba, xylophone",
643: "mask",
644: "matchstick",
645: "maypole",
646: "maze, labyrinth",
647: "measuring cup",
648: "medicine chest, medicine cabinet",
649: "megalith, megalithic structure",
650: "microphone, mike",
651: "microwave, microwave oven",
652: "military uniform",
653: "milk can",
654: "minibus",
655: "miniskirt, mini",
656: "minivan",
657: "missile",
658: "mitten",
659: "mixing bowl",
660: "mobile home, manufactured home",
661: "Model T",
662: "modem",
663: "monastery",
664: "monitor",
665: "moped",
666: "mortar",
667: "mortarboard",
668: "mosque",
669: "mosquito net",
670: "motor scooter, scooter",
671: "mountain bike, all-terrain bike, off-roader",
672: "mountain tent",
673: "mouse, computer mouse",
674: "mousetrap",
675: "moving van",
676: "muzzle",
677: "nail",
678: "neck brace",
679: "necklace",
680: "nipple",
681: "notebook, notebook computer",
682: "obelisk",
683: "oboe, hautboy, hautbois",
684: "ocarina, sweet potato",
685: "odometer, hodometer, mileometer, milometer",
686: "oil filter",
687: "organ, pipe organ",
688: "oscilloscope, scope, cathode-ray oscilloscope, CRO",
689: "overskirt",
690: "oxcart",
691: "oxygen mask",
692: "packet",
693: "paddle, boat paddle",
694: "paddlewheel, paddle wheel",
695: "padlock",
696: "paintbrush",
697: "pajama, pyjama, pj's, jammies",
698: "palace",
699: "panpipe, pandean pipe, syrinx",
700: "paper towel",
701: "parachute, chute",
702: "parallel bars, bars",
703: "park bench",
704: "parking meter",
705: "passenger car, coach, carriage",
706: "patio, terrace",
707: "pay-phone, pay-station",
708: "pedestal, plinth, footstall",
709: "pencil box, pencil case",
710: "pencil sharpener",
711: "perfume, essence",
712: "Petri dish",
713: "photocopier",
714: "pick, plectrum, plectron",
715: "pickelhaube",
716: "picket fence, paling",
717: "pickup, pickup truck",
718: "pier",
719: "piggy bank, penny bank",
720: "pill bottle",
721: "pillow",
722: "ping-pong ball",
723: "pinwheel",
724: "pirate, pirate ship",
725: "pitcher, ewer",
726: "plane, carpenter's plane, woodworking plane",
727: "planetarium",
728: "plastic bag",
729: "plate rack",
730: "plow, plough",
731: "plunger, plumber's helper",
732: "Polaroid camera, Polaroid Land camera",
733: "pole",
734: "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
735: "poncho",
736: "pool table, billiard table, snooker table",
737: "pop bottle, soda bottle",
738: "pot, flowerpot",
739: "potter's wheel",
740: "power drill",
741: "prayer rug, prayer mat",
742: "printer",
743: "prison, prison house",
744: "projectile, missile",
745: "projector",
746: "puck, hockey puck",
747: "punching bag, punch bag, punching ball, punchball",
748: "purse",
749: "quill, quill pen",
750: "quilt, comforter, comfort, puff",
751: "racer, race car, racing car",
752: "racket, racquet",
753: "radiator",
754: "radio, wireless",
755: "radio telescope, radio reflector",
756: "rain barrel",
757: "recreational vehicle, RV, R.V.",
758: "reel",
759: "reflex camera",
760: "refrigerator, icebox",
761: "remote control, remote",
762: "restaurant, eating house, eating place, eatery",
763: "revolver, six-gun, six-shooter",
764: "rifle",
765: "rocking chair, rocker",
766: "rotisserie",
767: "rubber eraser, rubber, pencil eraser",
768: "rugby ball",
769: "rule, ruler",
770: "running shoe",
771: "safe",
772: "safety pin",
773: "saltshaker, salt shaker",
774: "sandal",
775: "sarong",
776: "sax, saxophone",
777: "scabbard",
778: "scale, weighing machine",
779: "school bus",
780: "schooner",
781: "scoreboard",
782: "screen, CRT screen",
783: "screw",
784: "screwdriver",
785: "seat belt, seatbelt",
786: "sewing machine",
787: "shield, buckler",
788: "shoe shop, shoe-shop, shoe store",
789: "shoji",
790: "shopping basket",
791: "shopping cart",
792: "shovel",
793: "shower cap",
794: "shower curtain",
795: "ski",
796: "ski mask",
797: "sleeping bag",
798: "slide rule, slipstick",
799: "sliding door",
800: "slot, one-armed bandit",
801: "snorkel",
802: "snowmobile",
803: "snowplow, snowplough",
804: "soap dispenser",
805: "soccer ball",
806: "sock",
807: "solar dish, solar collector, solar furnace",
808: "sombrero",
809: "soup bowl",
810: "space bar",
811: "space heater",
812: "space shuttle",
813: "spatula",
814: "speedboat",
815: "spider web, spider's web",
816: "spindle",
817: "sports car, sport car",
818: "spotlight, spot",
819: "stage",
820: "steam locomotive",
821: "steel arch bridge",
822: "steel drum",
823: "stethoscope",
824: "stole",
825: "stone wall",
826: "stopwatch, stop watch",
827: "stove",
828: "strainer",
829: "streetcar, tram, tramcar, trolley, trolley car",
830: "stretcher",
831: "studio couch, day bed",
832: "stupa, tope",
833: "submarine, pigboat, sub, U-boat",
834: "suit, suit of clothes",
835: "sundial",
836: "sunglass",
837: "sunglasses, dark glasses, shades",
838: "sunscreen, sunblock, sun blocker",
839: "suspension bridge",
840: "swab, swob, mop",
841: "sweatshirt",
842: "swimming trunks, bathing trunks",
843: "swing",
844: "switch, electric switch, electrical switch",
845: "syringe",
846: "table lamp",
847: "tank, army tank, armored combat vehicle, armoured combat vehicle",
848: "tape player",
849: "teapot",
850: "teddy, teddy bear",
851: "television, television system",
852: "tennis ball",
853: "thatch, thatched roof",
854: "theater curtain, theatre curtain",
855: "thimble",
856: "thresher, thrasher, threshing machine",
857: "throne",
858: "tile roof",
859: "toaster",
860: "tobacco shop, tobacconist shop, tobacconist",
861: "toilet seat",
862: "torch",
863: "totem pole",
864: "tow truck, tow car, wrecker",
865: "toyshop",
866: "tractor",
867: "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
868: "tray",
869: "trench coat",
870: "tricycle, trike, velocipede",
871: "trimaran",
872: "tripod",
873: "triumphal arch",
874: "trolleybus, trolley coach, trackless trolley",
875: "trombone",
876: "tub, vat",
877: "turnstile",
878: "typewriter keyboard",
879: "umbrella",
880: "unicycle, monocycle",
881: "upright, upright piano",
882: "vacuum, vacuum cleaner",
883: "vase",
884: "vault",
885: "velvet",
886: "vending machine",
887: "vestment",
888: "viaduct",
889: "violin, fiddle",
890: "volleyball",
891: "waffle iron",
892: "wall clock",
893: "wallet, billfold, notecase, pocketbook",
894: "wardrobe, closet, press",
895: "warplane, military plane",
896: "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
897: "washer, automatic washer, washing machine",
898: "water bottle",
899: "water jug",
900: "water tower",
901: "whiskey jug",
902: "whistle",
903: "wig",
904: "window screen",
905: "window shade",
906: "Windsor tie",
907: "wine bottle",
908: "wing",
909: "wok",
910: "wooden spoon",
911: "wool, woolen, woollen",
912: "worm fence, snake fence, snake-rail fence, Virginia fence",
913: "wreck",
914: "yawl",
915: "yurt",
916: "web site, website, internet site, site",
917: "comic book",
918: "crossword puzzle, crossword",
919: "street sign",
920: "traffic light, traffic signal, stoplight",
921: "book jacket, dust cover, dust jacket, dust wrapper",
922: "menu",
923: "plate",
924: "guacamole",
925: "consomme",
926: "hot pot, hotpot",
927: "trifle",
928: "ice cream, icecream",
929: "ice lolly, lolly, lollipop, popsicle",
930: "French loaf",
931: "bagel, beigel",
932: "pretzel",
933: "cheeseburger",
934: "hotdog, hot dog, red hot",
935: "mashed potato",
936: "head cabbage",
937: "broccoli",
938: "cauliflower",
939: "zucchini, courgette",
940: "spaghetti squash",
941: "acorn squash",
942: "butternut squash",
943: "cucumber, cuke",
944: "artichoke, globe artichoke",
945: "bell pepper",
946: "cardoon",
947: "mushroom",
948: "Granny Smith",
949: "strawberry",
950: "orange",
951: "lemon",
952: "fig",
953: "pineapple, ananas",
954: "banana",
955: "jackfruit, jak, jack",
956: "custard apple",
957: "pomegranate",
958: "hay",
959: "carbonara",
960: "chocolate sauce, chocolate syrup",
961: "dough",
962: "meat loaf, meatloaf",
963: "pizza, pizza pie",
964: "potpie",
965: "burrito",
966: "red wine",
967: "espresso",
968: "cup",
969: "eggnog",
970: "alp",
971: "bubble",
972: "cliff, drop, drop-off",
973: "coral reef",
974: "geyser",
975: "lakeside, lakeshore",
976: "promontory, headland, head, foreland",
977: "sandbar, sand bar",
978: "seashore, coast, seacoast, sea-coast",
979: "valley, vale",
980: "volcano",
981: "ballplayer, baseball player",
982: "groom, bridegroom",
983: "scuba diver",
984: "rapeseed",
985: "daisy",
986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
987: "corn",
988: "acorn",
989: "hip, rose hip, rosehip",
990: "buckeye, horse chestnut, conker",
991: "coral fungus",
992: "agaric",
993: "gyromitra",
994: "stinkhorn, carrion fungus",
995: "earthstar",
996: "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
997: "bolete",
998: "ear, spike, capitulum",
999: "toilet tissue, toilet paper, bathroom tissue",
}
...@@ -76,6 +76,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -76,6 +76,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
("ViTConfig", "ViTForImageClassification"), ("ViTConfig", "ViTForImageClassification"),
("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"), ("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"),
("BeitConfig", "BeitForImageClassification"),
] ]
) )
...@@ -261,6 +262,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -261,6 +262,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
MODEL_MAPPING_NAMES = OrderedDict( MODEL_MAPPING_NAMES = OrderedDict(
[ [
("BeitConfig", "BeitModel"),
("RemBertConfig", "RemBertModel"), ("RemBertConfig", "RemBertModel"),
("VisualBertConfig", "VisualBertModel"), ("VisualBertConfig", "VisualBertModel"),
("CanineConfig", "CanineModel"), ("CanineConfig", "CanineModel"),
......
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision
from .test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import BeitFeatureExtractor
class BeitFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=20,
do_center_crop=True,
crop_size=18,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
def prepare_feat_extract_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_center_crop": self.do_center_crop,
"crop_size": self.crop_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
}
@require_torch
@require_vision
class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
feature_extraction_class = BeitFeatureExtractor if is_vision_available() else None
def setUp(self):
self.feature_extract_tester = BeitFeatureExtractionTester(self)
@property
def feat_extract_dict(self):
return self.feature_extract_tester.prepare_feat_extract_dict()
def test_feat_extract_properties(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "center_crop"))
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_batch_feature(self):
pass
def test_call_pil(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)
def test_call_numpy(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random numpy tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)
def test_call_pytorch(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size,
self.feature_extract_tester.crop_size,
),
)
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch BEiT model. """
import inspect
import unittest
from transformers import BeitConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import MODEL_MAPPING, BeitForImageClassification, BeitForMaskedImageModeling, BeitModel
from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
if is_vision_available():
from PIL import Image
from transformers import BeitFeatureExtractor
class BeitModelTester:
def __init__(
self,
parent,
vocab_size=100,
batch_size=13,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
scope=None,
):
self.parent = parent
self.vocab_size = 100
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return BeitConfig(
vocab_size=self.vocab_size,
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values, labels):
model = BeitModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def create_and_check_for_masked_lm(self, config, pixel_values, labels):
model = BeitForMaskedImageModeling(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected sequence length = num_patches
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = BeitForImageClassification(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
pixel_values,
labels,
) = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class BeitModelTest(ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as BEiT does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling) if is_torch_available() else ()
)
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = BeitModelTester(self)
self.config_tester = ConfigTester(self, config_class=BeitConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_inputs_embeds(self):
# BEiT does not use inputs_embeds
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_training(self):
if not self.model_tester.is_training:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING):
continue
# we don't test BeitForMaskedImageModeling
if model_class.__name__ == "BeitForMaskedImageModeling":
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
# we skip lambda parameters as these require special initial values
# determined by config.layer_scale_init_value
if "lambda" in name:
continue
if param.requires_grad:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in BEiT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# BEiT has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in BEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = BeitModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_vision
class BeitModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None
)
@slow
def test_inference_image_classification_head_imagenet_1k(self):
model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224").to(torch_device)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
outputs = model(**inputs)
logits = outputs.logits
# verify the logits
expected_shape = torch.Size((1, 1000))
self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor([-1.2385, -1.0987, -1.0108]).to(torch_device)
self.assertTrue(torch.allclose(logits[0, :3], expected_slice, atol=1e-4))
expected_class_idx = 281
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
@slow
def test_inference_image_classification_head_imagenet_22k(self):
model = BeitForImageClassification.from_pretrained("microsoft/beit-large-patch16-224-pt22k-ft22k").to(
torch_device
)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
outputs = model(**inputs)
logits = outputs.logits
# verify the logits
expected_shape = torch.Size((1, 21841))
self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor([1.6881, -0.2787, 0.5901]).to(torch_device)
self.assertTrue(torch.allclose(logits[0, :3], expected_slice, atol=1e-4))
expected_class_idx = 2396
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
...@@ -97,6 +97,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ ...@@ -97,6 +97,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule. # should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping # models to ignore for model xxx mapping
"BeitForMaskedImageModeling",
"CLIPTextModel", "CLIPTextModel",
"CLIPVisionModel", "CLIPVisionModel",
"FlaxCLIPTextModel", "FlaxCLIPTextModel",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment