test_models.py 36 KB
Newer Older
1
import contextlib
2
3
4
import functools
import operator
import os
5
import pkgutil
6
import platform
7
import sys
8
import warnings
9
from collections import OrderedDict
10
from tempfile import TemporaryDirectory
11
from typing import Any
12
13

import pytest
14
import torch
15
import torch.fx
16
import torch.nn as nn
17
from _utils_internal import get_relative_path
18
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
19
20
from PIL import Image
from torchvision import models, transforms
21
from torchvision.models import get_model_builder, list_models
22

23

24
# Environment switches read once at import time.
# EXPECTTEST_ACCEPT=1 makes _assert_expected record the current output instead of comparing against it.
ACCEPT = os.environ.get("EXPECTTEST_ACCEPT", "0") == "1"
# SKIP_BIG_MODEL defaults to "1": models listed in skipped_big_models are skipped on the matching platform/device.
SKIP_BIG_MODEL = os.environ.get("SKIP_BIG_MODEL", "1") == "1"
26
27


28
def list_model_fns(module):
    """Return the builder function of every model name that ``list_models(module)`` reports.

    Each name is mapped back to its builder with ``get_model_builder``; the resulting list is used to
    parametrize the per-family model tests.
    """
    return [get_model_builder(name) for name in list_models(module)]
30
31


32
33
34
35
36
37
38
39
40
41
42
def _get_image(input_shape, real_image, device):
    """This routine loads a real or random image based on `real_image` argument.
    Currently, the real image is utilized for the following list of models:
    - `retinanet_resnet50_fpn`,
    - `retinanet_resnet50_fpn_v2`,
    - `keypointrcnn_resnet50_fpn`,
    - `fasterrcnn_resnet50_fpn`,
    - `fasterrcnn_resnet50_fpn_v2`,
    - `fcos_resnet50_fpn`,
    - `maskrcnn_resnet50_fpn`,
    - `maskrcnn_resnet50_fpn_v2`,
Aidyn-A's avatar
Aidyn-A committed
43
    in `test_classification_model` and `test_detection_model`.
44
45
46
    To do so, a keyword argument `real_image` was added to the abovelisted models in `_model_params`
    """
    if real_image:
47
48
49
        # TODO: Maybe unify file discovery logic with test_image.py
        GRACE_HOPPER = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
50
        )
51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
        img = Image.open(GRACE_HOPPER)

        original_width, original_height = img.size

        # make the image square
        img = img.crop((0, 0, original_width, original_width))
        img = img.resize(input_shape[1:3])

        convert_tensor = transforms.ToTensor()
        image = convert_tensor(img)
        assert tuple(image.size()) == input_shape
        return image.to(device=device)

    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
    return torch.rand(input_shape).to(device=device)


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@pytest.fixture
def disable_weight_loading(mocker):
    """When testing models, the two slowest operations are the downloading of the weights to a file and loading them
    into the model. Unless you want to test against specific weights, these steps can be disabled without any
    drawbacks.

    Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
    through all already-imported modules under `torchvision.models` and will patch (via `mocker.patch`, i.e. replace
    with mocks that do nothing) all occurrences of the function `load_state_dict_from_url`, as well as the method
    `load_state_dict` on `nn.Module` and on every `nn.Module` subclass that defines its own `load_state_dict`.

    .. warning::

        Loaded models are still executable as normal, but will always have random weights. Make sure to not use this
        fixture if you want to compare the model output against reference values.

    """
    starting_point = models
    function_name = "load_state_dict_from_url"
    method_name = "load_state_dict"

    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
    targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
    for name in module_names:
        # Only modules that are already present in sys.modules are inspected.
        module = sys.modules.get(name)
        if not module:
            continue

        # Patch the function where a module has bound it into its own namespace.
        if function_name in module.__dict__:
            targets.add(f"{module.__name__}.{function_name}")

        # Patch subclasses that override load_state_dict themselves; patching nn.Module alone would miss them.
        targets.update(
            {
                f"{module.__name__}.{obj.__name__}.{method_name}"
                for obj in module.__dict__.values()
                if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
            }
        )

    for target in targets:
        # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
        with contextlib.suppress(AttributeError):
            mocker.patch(target)


114
115
116
117
118
119
def _get_expected_file(name=None):
    # Determine expected file based on environment
    expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")

    # Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names
    # We hardcode it here to avoid having to re-generate the reference files
120
    expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name)
121
122
123
124
125
126
127
128
129
130
131
132
    expected_file += "_expect.pkl"

    if not ACCEPT and not os.path.exists(expected_file):
        raise RuntimeError(
            f"No expect file exists for {os.path.basename(expected_file)} in {expected_file}; "
            "to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
            "env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
        )

    return expected_file


133
def _assert_expected(output, name, prec=None, atol=None, rtol=None):
    """Test that a python value matches the recorded contents of a file
    based on a "check" name. The value must be
    pickable with `torch.save`. This file
    is placed in the 'expect' directory in the same directory
    as the test script. You can automatically update the recorded test
    output using an EXPECTTEST_ACCEPT=1 env variable.

    Args:
        output: value to compare against (or, in accept mode, to record as) the reference.
        name: check name used to locate the reference file.
        prec: legacy tolerance used for both ``rtol`` and ``atol`` when those are not given.
        atol, rtol: tolerances forwarded to ``torch.testing.assert_close``.

    Raises:
        RuntimeError: in accept mode, if the saved file is larger than 50 KB.
    """
    expected_file = _get_expected_file(name)

    if ACCEPT:
        # Plain string: a set literal here would print as "{'...'}" in the messages below.
        filename = os.path.basename(expected_file)
        print(f"Accepting updated output for {filename}:\n\n{output}")
        torch.save(output, expected_file)
        MAX_PICKLE_SIZE = 50 * 1000  # 50 KB
        binary_size = os.path.getsize(expected_file)  # in bytes
        if binary_size > MAX_PICKLE_SIZE:
            raise RuntimeError(
                f"The output for {filename} is larger than {MAX_PICKLE_SIZE // 1000}kb - "
                f"got {binary_size / 1000:.1f}kb"
            )
    else:
        expected = torch.load(expected_file)
        rtol = rtol or prec  # keeping prec param for legacy reason, but could be removed ideally
        atol = atol or prec
        torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, check_device=False)
156
157


158
def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):
    """Check that a nn.Module's results in TorchScript match eager and that it can be exported

    Args:
        nn_module: eager module to script.
        args: tuple of positional inputs passed to every forward call.
        unwrapper: optional callable applied to the scripted outputs before comparison, for models whose
            TorchScript outputs differ in structure from eager mode.
        eager_out: precomputed eager output; if ``None`` it is computed here under ``torch.no_grad()`` and
            ``freeze_rng_state()``.
    """

    def get_export_import_copy(m):
        """Save and load a TorchScript model"""
        with TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "script.pt")
            m.save(path)
            imported = torch.jit.load(path)
        return imported

    sm = torch.jit.script(nn_module)
    sm.eval()

    if eager_out is None:
        with torch.no_grad(), freeze_rng_state():
            eager_out = nn_module(*args)

    with torch.no_grad(), freeze_rng_state():
        script_out = sm(*args)
        if unwrapper:
            script_out = unwrapper(script_out)

    torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)

    # Round-trip the scripted module through save/load and check the reloaded copy reproduces its output.
    m_import = get_export_import_copy(sm)
    with torch.no_grad(), freeze_rng_state():
        imported_script_out = m_import(*args)
        if unwrapper:
            imported_script_out = unwrapper(imported_script_out)

    torch.testing.assert_close(script_out, imported_script_out, atol=3e-4, rtol=3e-4)
190
191


192
def _check_fx_compatible(model, inputs, eager_out=None):
    """Check that ``torch.fx.symbolic_trace(model)`` reproduces the eager output of ``model`` on ``inputs``.

    If ``eager_out`` is not supplied, it is computed under the same ``torch.no_grad()`` / ``freeze_rng_state()``
    context as the traced run. This matches how ``_check_jit_scriptable`` computes its eager reference.
    Assuming ``freeze_rng_state`` restores the RNG state on exit (TODO confirm in common_utils), both passes then
    start from the same RNG state.
    """
    model_fx = torch.fx.symbolic_trace(model)

    if eager_out is None:
        with torch.no_grad(), freeze_rng_state():
            eager_out = model(inputs)

    with torch.no_grad(), freeze_rng_state():
        fx_out = model_fx(inputs)
    torch.testing.assert_close(eager_out, fx_out)
199
200


201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def _check_input_backprop(model, inputs):
    if isinstance(inputs, list):
        requires_grad = list()
        for inp in inputs:
            requires_grad.append(inp.requires_grad)
            inp.requires_grad_(True)
    else:
        requires_grad = inputs.requires_grad
        inputs.requires_grad_(True)

    out = model(inputs)

    if isinstance(out, dict):
        out["out"].sum().backward()
    else:
        if isinstance(out[0], dict):
            out[0]["scores"].sum().backward()
        else:
            out[0].sum().backward()

    if isinstance(inputs, list):
        for i, inp in enumerate(inputs):
            assert inputs[i].grad is not None
            inp.requires_grad_(requires_grad[i])
    else:
        assert inputs.grad is not None
        inputs.requires_grad_(requires_grad)


230
231
232
# If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode
script_model_unwrapper = {
    "googlenet": lambda x: x.logits,
    "inception_v3": lambda x: x.logits,
    "fasterrcnn_resnet50_fpn": lambda x: x[1],
    "fasterrcnn_resnet50_fpn_v2": lambda x: x[1],
    "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
    "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
    "maskrcnn_resnet50_fpn": lambda x: x[1],
    "maskrcnn_resnet50_fpn_v2": lambda x: x[1],
    "keypointrcnn_resnet50_fpn": lambda x: x[1],
    "retinanet_resnet50_fpn": lambda x: x[1],
    "retinanet_resnet50_fpn_v2": lambda x: x[1],
    "ssd300_vgg16": lambda x: x[1],
    "ssdlite320_mobilenet_v3_large": lambda x: x[1],
    "fcos_resnet50_fpn": lambda x: x[1],
}
249
250


251
252
253
254
255
256
257
258
259
260
261
262
263
264
# The following models exhibit flaky numerics under autocast in _test_*_model harnesses.
# This may be caused by the harness environment (e.g. num classes, input initialization
# via torch.rand), and does not prove autocast is unsuitable when training with real data
# (autocast has been used successfully with real data for some of these models).
# TODO: investigate why autocast numerics are flaky in the harnesses.
#
# For the following models, _test_*_model harnesses skip numerical checks on outputs when
# trying autocast. However, they still try an autocasted forward pass, so they still ensure
# autocast coverage suffices to prevent dtype errors in each model.
autocast_flaky_numerics = (
    "inception_v3",
    "resnet101",
    "resnet152",
    "wide_resnet101_2",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
    "fcn_resnet50",
    "fcn_resnet101",
    "lraspp_mobilenet_v3_large",
    "maskrcnn_resnet50_fpn",
    "maskrcnn_resnet50_fpn_v2",
    "keypointrcnn_resnet50_fpn",
)

276
277
278
# Quantized models whose tests are flaky, possibly because rounding errors are not
# consistent across platforms. The input/output consistency checks in
# test_quantized_classification_model are therefore skipped for these models.
quantized_flaky_models = (
    "inception_v3",
    "resnet50",
)
280

281

282
283
284
# The following contains configuration parameters for all models which are used by
# the _test_*_model methods.
_model_params = {
285
    "inception_v3": {"input_shape": (1, 3, 299, 299), "init_weights": True},
286
287
288
289
290
291
    "retinanet_resnet50_fpn": {
        "num_classes": 20,
        "score_thresh": 0.01,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
292
        "real_image": True,
293
    },
294
295
296
297
298
299
    "retinanet_resnet50_fpn_v2": {
        "num_classes": 20,
        "score_thresh": 0.01,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
300
        "real_image": True,
301
    },
302
303
304
305
    "keypointrcnn_resnet50_fpn": {
        "num_classes": 2,
        "min_size": 224,
        "max_size": 224,
306
        "box_score_thresh": 0.17,
307
        "input_shape": (3, 224, 224),
308
        "real_image": True,
309
    },
310
311
312
313
314
    "fasterrcnn_resnet50_fpn": {
        "num_classes": 20,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
315
        "real_image": True,
316
    },
317
318
319
320
321
    "fasterrcnn_resnet50_fpn_v2": {
        "num_classes": 20,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
322
        "real_image": True,
323
    },
Hu Ye's avatar
Hu Ye committed
324
325
326
327
328
329
    "fcos_resnet50_fpn": {
        "num_classes": 2,
        "score_thresh": 0.05,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
330
        "real_image": True,
Hu Ye's avatar
Hu Ye committed
331
    },
332
333
334
335
336
    "maskrcnn_resnet50_fpn": {
        "num_classes": 10,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
337
        "real_image": True,
338
    },
339
340
341
342
343
    "maskrcnn_resnet50_fpn_v2": {
        "num_classes": 10,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
344
        "real_image": True,
345
    },
346
347
    "fasterrcnn_mobilenet_v3_large_fpn": {
        "box_score_thresh": 0.02076,
348
    },
349
350
351
352
    "fasterrcnn_mobilenet_v3_large_320_fpn": {
        "box_score_thresh": 0.02076,
        "rpn_pre_nms_top_n_test": 1000,
        "rpn_post_nms_top_n_test": 1000,
353
    },
354
355
356
357
    "vit_h_14": {
        "image_size": 56,
        "input_shape": (1, 3, 56, 56),
    },
358
359
360
    "mvit_v1_b": {
        "input_shape": (1, 3, 16, 224, 224),
    },
361
362
363
    "mvit_v2_s": {
        "input_shape": (1, 3, 16, 224, 224),
    },
364
365
366
    "s3d": {
        "input_shape": (1, 3, 16, 224, 224),
    },
367
    "googlenet": {"init_weights": True},
368
}
369
370
371
372
373
# Slow models: their tests are sped up by overriding their input_shape in _model_params.
slow_models = [
    "convnext_base",
    "convnext_large",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "wide_resnet101_2",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_m",
    "efficientnet_v2_l",
    "regnet_y_16gf",
    "regnet_y_32gf",
    "regnet_y_128gf",
    "regnet_x_16gf",
    "regnet_x_32gf",
    "swin_t",
    "swin_s",
    "swin_b",
    "swin_v2_t",
    "swin_v2_s",
    "swin_v2_b",
]
for m in slow_models:
    # A fresh 64x64 input-shape config per slow model keeps their tests fast.
    _model_params[m] = dict(input_shape=(1, 3, 64, 64))
394
395


396
# skip big models to reduce memory usage on CI test. We can exclude combinations of (platform-system, device).
skipped_big_models = {
    "vit_h_14": {("Windows", "cpu"), ("Windows", "cuda")},
    "regnet_y_128gf": {("Windows", "cpu"), ("Windows", "cuda")},
    "mvit_v1_b": {("Windows", "cuda"), ("Linux", "cuda")},
    "mvit_v2_s": {("Windows", "cuda"), ("Linux", "cuda")},
}

404
405
406
407
408
409
410
411
412
413
414

def is_skippable(model_name, device):
    """Return True when ``model_name`` is excluded for the current OS and the type of ``device``.

    ``device`` is converted with ``str`` and anything after the first ``:`` (e.g. a CUDA index) is dropped.
    The resulting ``(platform.system(), device_type)`` pair is then looked up in ``skipped_big_models``.
    """
    excluded_pairs = skipped_big_models.get(model_name)
    if excluded_pairs is None:
        return False

    device_type, _, _ = str(device).partition(":")
    return (platform.system(), device_type) in excluded_pairs


415
416
417
418
419
420
# The following contains configuration and expected values to be used tests that are model specific
_model_tests_values = {
    "retinanet_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [36, 46, 65, 78, 88, 89],
    },
421
422
423
424
    "retinanet_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [44, 74, 131, 170, 200, 203],
    },
425
426
427
428
429
430
431
432
    "keypointrcnn_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [48, 58, 77, 90, 100, 101],
    },
    "fasterrcnn_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [30, 40, 59, 72, 82, 83],
    },
433
434
435
436
    "fasterrcnn_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [50, 80, 137, 176, 206, 209],
    },
437
438
439
440
    "maskrcnn_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [42, 52, 71, 84, 94, 95],
    },
441
442
443
444
    "maskrcnn_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [66, 96, 153, 192, 222, 225],
    },
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
    "fasterrcnn_mobilenet_v3_large_fpn": {
        "max_trainable": 6,
        "n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
    },
    "fasterrcnn_mobilenet_v3_large_320_fpn": {
        "max_trainable": 6,
        "n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
    },
    "ssd300_vgg16": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [45, 51, 57, 63, 67, 71],
    },
    "ssdlite320_mobilenet_v3_large": {
        "max_trainable": 6,
        "n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266],
    },
Hu Ye's avatar
Hu Ye committed
461
462
463
464
    "fcos_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [54, 64, 83, 96, 106, 107],
    },
465
466
467
}


Anirudh's avatar
Anirudh committed
468
469
470
471
472
473
474
475
476
477
def _make_sliced_model(model, stop_layer):
    layers = OrderedDict()
    for name, layer in model.named_children():
        layers[name] = layer
        if name == stop_layer:
            break
    new_model = torch.nn.Sequential(layers)
    return new_model


478
479
@pytest.mark.parametrize("model_fn", [models.densenet121, models.densenet169, models.densenet201, models.densenet161])
def test_memory_efficient_densenet(model_fn):
    """Check a memory-efficient DenseNet gives gradients to all of its parameters and matches the output of the
    regular variant loaded with the same weights."""
    input_shape = (1, 3, 300, 300)
    x = torch.rand(input_shape)

    model1 = model_fn(num_classes=50, memory_efficient=True)
    params = model1.state_dict()
    num_params = sum(p.numel() for p in model1.parameters())
    model1.eval()
    out1 = model1(x)
    out1.sum().backward()
    num_grad = sum(p.grad.numel() for p in model1.parameters() if p.grad is not None)

    model2 = model_fn(num_classes=50, memory_efficient=False)
    model2.load_state_dict(params)
    model2.eval()
    out2 = model2(x)

    assert num_params == num_grad
    torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)

    # Each check gets its own copy of the input: a gradient accumulated on a shared ``x`` by the first check
    # would make the second model's gradient assertion pass trivially.
    _check_input_backprop(model1, x.clone())
    _check_input_backprop(model2, x.clone())

@pytest.mark.parametrize("dilate_layer_2", (True, False))
@pytest.mark.parametrize("dilate_layer_3", (True, False))
@pytest.mark.parametrize("dilate_layer_4", (True, False))
def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4):
    """Check that every stage switched to dilation doubles the spatial size of ResNet-50's last feature map."""
    # TODO improve tests to also check that each layer has the right dimensionality
    model = models.resnet50(replace_stride_with_dilation=(dilate_layer_2, dilate_layer_3, dilate_layer_4))
    model = _make_sliced_model(model, stop_layer="layer4")
    model.eval()
    x = torch.rand(1, 3, 224, 224)
    out = model(x)
    f = 2 ** sum((dilate_layer_2, dilate_layer_3, dilate_layer_4))
    assert out.shape == (1, 2048, 7 * f, 7 * f)


def test_mobilenet_v2_residual_setting():
    """Check MobileNetV2 built with a custom ``inverted_residual_setting`` still returns 1000 logits."""
    model = models.mobilenet_v2(inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
    model.eval()
    x = torch.rand(1, 3, 224, 224)
    out = model(x)
    assert out.shape[-1] == 1000


525
526
527
@pytest.mark.parametrize("model_fn", [models.mobilenet_v2, models.mobilenet_v3_large, models.mobilenet_v3_small])
def test_mobilenet_norm_layer(model_fn):
    """Check that ``norm_layer`` replaces the default ``BatchNorm2d`` layers (here with single-group GroupNorm)."""
    model = model_fn()
    assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules())

    def get_gn(num_channels):
        return nn.GroupNorm(1, num_channels)

    model = model_fn(norm_layer=get_gn)
    assert not any(isinstance(x, nn.BatchNorm2d) for x in model.modules())
    assert any(isinstance(x, nn.GroupNorm) for x in model.modules())


def test_inception_v3_eval():
    """Check that an Inception3 with its auxiliary head removed is TorchScript-compatible and back-propagates
    in eval mode."""
    kwargs = {"transform_input": True, "aux_logits": True, "init_weights": False}
    name = "inception_v3"
    model = models.Inception3(**kwargs)
    # Drop the auxiliary classifier after construction so the eval-mode graph has a single output.
    model.aux_logits = False
    model.AuxLogits = None
    model = model.eval()
    x = torch.rand(1, 3, 299, 299)
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
    _check_input_backprop(model, x)
551
552
553


def test_fasterrcnn_double():
    """Check that Faster R-CNN runs in float64, leaves its input list untouched and returns boxes/scores/labels."""
    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None)
    model.double()
    model.eval()
    input_shape = (3, 300, 300)
    x = torch.rand(input_shape, dtype=torch.float64)
    model_input = [x]
    out = model(model_input)
    assert model_input[0] is x
    assert len(out) == 1
    assert "boxes" in out[0]
    assert "scores" in out[0]
    assert "labels" in out[0]
    _check_input_backprop(model, model_input)
567
568
569
570


def test_googlenet_eval():
    """Check that a GoogLeNet with its auxiliary heads removed is TorchScript-compatible and back-propagates
    in eval mode."""
    kwargs = {"transform_input": True, "aux_logits": True, "init_weights": False}
    name = "googlenet"
    model = models.GoogLeNet(**kwargs)
    # Drop both auxiliary classifiers after construction so the eval-mode graph has a single output.
    model.aux_logits = False
    model.aux1 = None
    model.aux2 = None
    model = model.eval()
    x = torch.rand(1, 3, 224, 224)
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
    _check_input_backprop(model, x)
583
584
585
586
587
588
589
590
591
592


@needs_cuda
def test_fasterrcnn_switch_devices():
    """Check that Faster R-CNN works on CUDA (also under autocast) and keeps working after moving to CPU."""

    def check_out(out):
        assert len(out) == 1
        assert "boxes" in out[0]
        assert "scores" in out[0]
        assert "labels" in out[0]

    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None)
    model.cuda()
    model.eval()
    input_shape = (3, 300, 300)
    x = torch.rand(input_shape, device="cuda")
    model_input = [x]
    out = model(model_input)
    assert model_input[0] is x

    check_out(out)

    with torch.cuda.amp.autocast():
        out = model(model_input)

    check_out(out)

    _check_input_backprop(model, model_input)

    # now switch to cpu and make sure it works
    model.cpu()
    x = x.cpu()
    out_cpu = model([x])

    check_out(out_cpu)

    _check_input_backprop(model, [x])

620

def test_generalizedrcnn_transform_repr():
    """Check that the repr of ``GeneralizedRCNNTransform`` lists its normalization and resize settings."""
    min_size, max_size = 224, 299
    image_mean = [0.485, 0.456, 0.406]
    image_std = [0.229, 0.224, 0.225]

    t = models.detection.transform.GeneralizedRCNNTransform(
        min_size=min_size, max_size=max_size, image_mean=image_mean, image_std=image_std
    )

    # Check integrity of object __repr__ attribute
    expected_string = "GeneralizedRCNNTransform("
    _indent = "\n    "
    expected_string += f"{_indent}Normalize(mean={image_mean}, std={image_std})"
    expected_string += f"{_indent}Resize(min_size=({min_size},), max_size={max_size}, "
    expected_string += "mode='bilinear')\n)"
    assert repr(t) == expected_string
638
639


640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
# Conv-stem layers used by vitc_b_16: every layer uses a 3x3 kernel; the pairs give (stride, out_channels).
test_vit_conv_stem_configs = [
    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=stride, out_channels=out_channels)
    for stride, out_channels in ((2, 64), (2, 128), (1, 128), (2, 256), (1, 256), (2, 512))
]


def vitc_b_16(**kwargs: Any):
    """Build a 224px ``VisionTransformer`` with patch size 16, 12 layers, 12 heads, hidden size 768, MLP size 3072
    and ``conv_stem_configs=test_vit_conv_stem_configs``; any extra ``kwargs`` go to the constructor as well."""
    base_config = dict(
        image_size=224,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        conv_stem_configs=test_vit_conv_stem_configs,
    )
    return models.VisionTransformer(**base_config, **kwargs)


@pytest.mark.parametrize("model_fn", [vitc_b_16])
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_vitc_models(model_fn, dev):
    """Run the generic classification-model checks on the conv-stem ViT builder(s)."""
    test_classification_model(model_fn=model_fn, dev=dev)


669
@pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_classification_model(model_fn, dev):
    """Check a classification model against its expected output, TorchScript, FX, input backprop and, on CUDA,
    an autocast forward pass."""
    set_rng_seed(0)
    defaults = {
        "num_classes": 50,
        "input_shape": (1, 3, 224, 224),
    }
    model_name = model_fn.__name__
    if SKIP_BIG_MODEL and is_skippable(model_name, dev):
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    num_classes = kwargs.get("num_classes")
    # input_shape / real_image configure the test input and are not builder arguments.
    input_shape = kwargs.pop("input_shape")
    real_image = kwargs.pop("real_image", False)

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
    out = model(x)
    _assert_expected(out.cpu(), model_name, prec=1e-3)
    assert out.shape[-1] == num_classes
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast():
            out = model(x)
            # See the autocast_flaky_numerics comment above.
            if model_name not in autocast_flaky_numerics:
                _assert_expected(out.cpu(), model_name, prec=0.1)
            # Compare with the configured class count rather than a hard-coded 50.
            assert out.shape[-1] == num_classes

    _check_input_backprop(model, x)

704

705
@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_segmentation_model(model_fn, dev):
    """Check a segmentation model against its expected output, TorchScript, FX, input backprop and, on CUDA,
    an autocast forward pass.

    When the full output does not match the reference, only the argmax class predictions are compared. In that
    case the test warns and is skipped (before the backprop check).
    """
    set_rng_seed(0)
    defaults = {
        "num_classes": 10,
        "weights_backbone": None,
        "input_shape": (1, 3, 32, 32),
    }
    model_name = model_fn.__name__
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    input_shape = kwargs.pop("input_shape")

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
    x = torch.rand(input_shape).to(device=dev)
    with torch.no_grad(), freeze_rng_state():
        out = model(x)

    def check_out(out):
        """Validate ``out`` against the reference; return True for a full match, False for an argmax-only match."""
        prec = 0.01
        try:
            # We first try to assert the entire output if possible. This is not
            # only the best way to assert results but also handles the cases
            # where we need to create a new expected result.
            _assert_expected(out.cpu(), model_name, prec=prec)
        except AssertionError:
            # Unfortunately some segmentation models are flaky with autocast
            # so instead of validating the probability scores, check that the class
            # predictions match.
            expected_file = _get_expected_file(model_name)
            expected = torch.load(expected_file)
            torch.testing.assert_close(
                out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec, check_device=False
            )
            return False  # Partial validation performed

        return True  # Full validation performed

    full_validation = check_out(out["out"])

    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
            out = model(x)
            # See the autocast_flaky_numerics comment above.
            if model_name not in autocast_flaky_numerics:
                full_validation &= check_out(out["out"])

    if not full_validation:
        msg = (
            f"The output of {test_segmentation_model.__name__} could only be partially validated. "
            "This is likely due to unit-test flakiness, but you may "
            "want to do additional manual checks if you made "
            "significant changes to the codebase."
        )
        warnings.warn(msg, RuntimeWarning)
        pytest.skip(msg)

    _check_input_backprop(model, x)

769

770
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_detection_model(model_fn, dev):
    """Run a detection builder on one image and compare its output with the stored reference.

    Each output tensor is reduced before comparison: tensors with more than 30
    elements per row (along dim 0) become a mean/std summary, smaller ones are
    thinned to roughly 20 rows. If the full summary does not match, only the
    ``scores`` entry of the single output is compared; in that case the test
    warns and is skipped to flag that validation was partial. On CUDA the check
    is repeated under autocast, and input backprop is exercised at the end.
    """
    set_rng_seed(0)
    model_name = model_fn.__name__

    kwargs = dict(num_classes=50, weights_backbone=None, input_shape=(3, 300, 300))
    kwargs.update(_model_params.get(model_name, {}))
    input_shape = kwargs.pop("input_shape")
    real_image = kwargs.pop("real_image", False)

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    image = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
    batch = [image]
    with torch.no_grad(), freeze_rng_state():
        out = model(batch)
    # The forward pass must leave the caller's input list untouched.
    assert batch[0] is image

    def _mean_std(tensor):
        # Integral tensors have no mean, so promote to float64 first.
        as_double = tensor.double()
        return {"mean": torch.mean(as_double), "std": torch.std(as_double)}

    def _subsample(tensor, num_samples=20):
        num_elems = tensor.size(0)
        if num_elems <= num_samples:
            return tensor
        step = num_elems // num_samples
        return tensor[step - 1 :: step]

    def _summarize(tensor):
        # Reduce an output tensor to a compact, comparable form on the CPU.
        tensor = tensor.cpu()
        elements_per_sample = functools.reduce(operator.mul, tensor.size()[1:], 1)
        return _mean_std(tensor) if elements_per_sample > 30 else _subsample(tensor)

    def _matches_expected(detections):
        """Return True if the whole summarized output matched, False if only the scores did."""
        assert len(detections) == 1
        summary = map_nested_tensor_object(detections, tensor_map_fn=_summarize)
        tolerance = 0.01
        try:
            # Comparing the entire output is the strongest check, and it is also the
            # path that records a new expected result when one has to be created.
            _assert_expected(summary, model_name, prec=tolerance)
        except AssertionError:
            # Detection models are flaky because of the unstable sort inside NMS.
            # Fall back to comparing the scores only, mirroring the approach used in
            # NMSTester.test_nms_cuda for duplicate scores.
            # Idea (proposed by fmassa): disable NMS through its threshold and match
            # output boxes to expected ones with the Hungarian algorithm, as in DETR.
            reference = torch.load(_get_expected_file(model_name))
            torch.testing.assert_close(
                summary[0]["scores"],
                reference[0]["scores"],
                rtol=tolerance,
                atol=tolerance,
                check_device=False,
                check_dtype=False,
            )
            return False
        return True

    full_validation = _matches_expected(out)
    _check_jit_scriptable(model, ([image],), unwrapper=script_model_unwrapper.get(model_name), eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
            out = model(batch)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics and not _matches_expected(out):
                full_validation = False

    if not full_validation:
        msg = (
            f"The output of {test_detection_model.__name__} could only be partially validated. "
            "This is likely due to unit-test flakiness, but you may "
            "want to do additional manual checks if you made "
            "significant changes to the codebase."
        )
        warnings.warn(msg, RuntimeWarning)
        pytest.skip(msg)

    _check_input_backprop(model, batch)

868

869
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_validation(model_fn):
    """Check that a freshly built detection model rejects malformed training calls.

    Every invalid call below is expected to raise ``AssertionError``.
    """
    set_rng_seed(0)
    model = model_fn(num_classes=50, weights=None, weights_backbone=None)
    images = [torch.rand((3, 300, 300))]

    def assert_rejected(**call_kwargs):
        with pytest.raises(AssertionError):
            model(images, **call_kwargs)

    # A call without targets is rejected (the module was never put in eval mode).
    assert_rejected()

    # "boxes" given as a plain float instead of a tensor.
    assert_rejected(targets=[{"boxes": 0.0}])

    # "boxes" tensors with the wrong shape.
    for bad_boxes in (torch.rand((4,)), torch.rand((1, 5))):
        assert_rejected(targets=[{"boxes": bad_boxes}])

    # Degenerate boxes: each row repeats one coordinate value across its pair.
    assert_rejected(targets=[{"boxes": torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])}])
896

897

898
@pytest.mark.parametrize("model_fn", list_model_fns(models.video))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_video_model(model_fn, dev):
    """Compare a video classification builder's output with its stored reference.

    The eager output is matched against the expected file and its last dimension
    is checked against ``num_classes``. The model is then checked for TorchScript
    and FX compatibility. On CUDA the forward pass is repeated under autocast with
    a looser tolerance. Input backprop is exercised at the end. Large models can
    be skipped via the SKIP_BIG_MODEL environment variable.
    """
    set_rng_seed(0)
    # the default input shape is
    # bs * num_channels * clip_len * h * w
    defaults = {
        "input_shape": (1, 3, 4, 112, 112),
        "num_classes": 50,
    }
    model_name = model_fn.__name__
    if SKIP_BIG_MODEL and is_skippable(model_name, dev):
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    num_classes = kwargs.get("num_classes")
    input_shape = kwargs.pop("input_shape")

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
    x = torch.rand(input_shape).to(device=dev)
    out = model(x)
    _assert_expected(out.cpu(), model_name, prec=1e-5)
    assert out.shape[-1] == num_classes
    # `out` is not reassigned by these checks, so the shape assertion above
    # does not need repeating afterwards.
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast():
            out = model(x)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                _assert_expected(out.cpu(), model_name, prec=0.1)
            assert out.shape[-1] == num_classes

    _check_input_backprop(model, x)

936

937
938
939
940
941
942
943
@pytest.mark.skipif(
    not (
        "fbgemm" in torch.backends.quantized.supported_engines
        and "qnnpack" in torch.backends.quantized.supported_engines
    ),
    reason="This Pytorch Build has not been built with fbgemm and qnnpack",
)
@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
def test_quantized_classification_model(model_fn):
    """Check a quantizable classification builder in quantized and float forms.

    First, the ``quantize=True`` model is run on random input. Unless the model is
    listed in ``quantized_flaky_models``, its output is compared with the stored
    ``<name>_quantized`` reference and checked for TorchScript/FX compatibility.
    Flaky models are only required to be scriptable. Next, the ``quantize=False``
    model is built twice and taken through fuse -> prepare -> convert:
      - once in eval mode with ``default_qconfig`` / ``prepare``;
      - once in train mode with ``default_qat_qconfig`` / ``prepare_qat``.
    Both passes must complete without raising.
    """
    set_rng_seed(0)
    defaults = {
        "num_classes": 5,
        "input_shape": (1, 3, 224, 224),
        "quantize": True,
    }
    model_name = model_fn.__name__
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    # Use the effective class count, since a per-model entry in _model_params may
    # override the default, rather than hard-coding the default in the check below.
    num_classes = kwargs.get("num_classes")
    input_shape = kwargs.pop("input_shape")

    # First check if quantize=True provides models that can run with input data
    model = model_fn(**kwargs)
    model.eval()
    x = torch.rand(input_shape)
    out = model(x)

    if model_name not in quantized_flaky_models:
        _assert_expected(out.cpu(), model_name + "_quantized", prec=2e-2)
        assert out.shape[-1] == num_classes
        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
        _check_fx_compatible(model, x, eager_out=out)
    else:
        try:
            torch.jit.script(model)
        except Exception as e:
            raise AssertionError("model cannot be scripted.") from e

    kwargs["quantize"] = False
    for eval_mode in [True, False]:
        model = model_fn(**kwargs)
        if eval_mode:
            model.eval()
            model.qconfig = torch.ao.quantization.default_qconfig
        else:
            model.train()
            model.qconfig = torch.ao.quantization.default_qat_qconfig

        model.fuse_model(is_qat=not eval_mode)
        if eval_mode:
            torch.ao.quantization.prepare(model, inplace=True)
        else:
            torch.ao.quantization.prepare_qat(model, inplace=True)
            # The QAT-prepared model is switched to eval mode before conversion.
            model.eval()

        torch.ao.quantization.convert(model, inplace=True)
991
992


993
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
    """Check the number of trainable parameters for each ``trainable_backbone_layers`` value.

    For every setting from 0 up to the model's recorded ``max_trainable``, the
    builder is called with ``weights_backbone="DEFAULT"``. The parameters with
    ``requires_grad`` set are counted, and the resulting list is compared with the
    recorded ``n_trn_params_per_layer``.

    ``disable_weight_loading`` is a fixture requested only for its side effect and
    is not used in the body. Presumably it prevents the DEFAULT backbone weights
    from actually being loaded; confirm against the fixture definition.
    """
    reference = _model_tests_values[model_fn.__name__]

    def count_trainable(num_layers):
        candidate = model_fn(weights=None, weights_backbone="DEFAULT", trainable_backbone_layers=num_layers)
        return sum(1 for param in candidate.parameters() if param.requires_grad)

    observed = [count_trainable(layers) for layers in range(reference["max_trainable"] + 1)]
    assert observed == reference["n_trn_params_per_layer"]


1005
@needs_cuda
@pytest.mark.parametrize("model_fn", list_model_fns(models.optical_flow))
@pytest.mark.parametrize("scripted", (False, True))
def test_raft(model_fn, scripted):
    """Compare a RAFT optical-flow model's last prediction with its stored reference.

    Runs on CUDA, optionally through ``torch.jit.script``, on a pair of small
    random frames.
    """
    torch.manual_seed(0)

    # The frames must be very small, otherwise the stored pickle would exceed 50KB.
    # With inputs this small, the default correlation pyramid would downsample the
    # effective H and W down to 1 and yield NaN values. A shallower pyramid
    # (2 levels, radius 2) is therefore used instead.
    shallow_corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)

    net = model_fn(corr_block=shallow_corr_block).eval().to("cuda")
    if scripted:
        net = torch.jit.script(net)

    frame_shape = (1, 3, 80, 72)
    frame_a = torch.rand(frame_shape).cuda()
    frame_b = torch.rand(frame_shape).cuda()

    final_flow = net(frame_a, frame_b)[-1]
    # The tolerance is fairly loose because there are 2 * H * W outputs to check.
    # Also, the reference .pkl was generated on the AWS cluster, and results on CI
    # appear to differ slightly.
    _assert_expected(final_flow.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1)
1031
1032


1033
if __name__ == "__main__":
    # Forward pytest's exit code so that running this file directly reports
    # failures through the process exit status instead of always exiting 0.
    sys.exit(pytest.main([__file__]))