test_fused_optimizer.py 27.5 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

import copy
6
from contextlib import nullcontext
7

8
import pytest
9
10
11
12
import torch
from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
13
from transformer_engine.common.recipe import DelayedScaling
14
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
15
16
17
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
18
from transformer_engine.pytorch.utils import gpu_autocast_ctx
19
20
21

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
22
23


24
25
26
class TestFusedOptimizer:
    """Shared harness comparing a fused optimizer with a reference optimizer.

    Subclasses are expected to assign ``self.options``, ``self.ref_optim`` and
    ``self.fused_optim`` (and optionally ``self.tst_options``) before calling
    the ``gen_*`` helpers.
    """

    def setup_method(self, *, iters: int = 7) -> None:
        """Store the number of optimizer steps per test and seed the torch RNG."""
        self.iters = iters
        torch.manual_seed(9876)

    def gen_param_optim(self, tensors, options, tst_options=None):
        """Create independent parameter copies and one optimizer for each copy.

        Args:
            tensors: Tensors to clone into ``torch.nn.Parameter`` objects.
            options: Keyword arguments for the reference optimizer.
            tst_options: Keyword arguments for the optimizer under test.
                Falls back to ``options`` when omitted, which keeps older
                call sites working.

        Returns:
            Tuple ``(ref_param, tst_param, ref_optim, tst_optim)``.
        """
        if tst_options is None:
            tst_options = options

        ref_param = [torch.nn.Parameter(tensor.clone()) for tensor in tensors]
        tst_param = [torch.nn.Parameter(tensor.clone()) for tensor in tensors]

        ref_optim = self.ref_optim(ref_param, **options)
        tst_optim = self.fused_optim(tst_param, **tst_options)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        """Assign one random gradient tensor to each reference/test parameter pair."""
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            # Both parameters deliberately point at the same gradient tensor.
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        """Create half-precision gradients and set unscaled FP32 copies on ``ref_param``.

        Only the reference parameters receive a ``.grad`` here; the generated
        half-precision tensors are returned to the caller.
        """
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def gen_single_type_test(
        self, param_type=torch.float, device="cuda", *, skip_assert: bool = False
    ):
        """Step both optimizers on one flat tensor and compare the parameters.

        Args:
            param_type: dtype of the generated parameter tensor.
            device: Device on which the tensor is created.
            skip_assert: When true, every step still runs but the
                per-step comparison is skipped.
        """
        nelem = 278011

        # Reference and test optimizers may need different options. When the
        # subclass did not provide "tst_options", reuse the reference options.
        if not hasattr(self, "tst_options"):
            self.tst_options = self.options

        tensor = torch.rand(nelem, dtype=param_type, device=device)

        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], self.options, self.tst_options
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            if skip_assert:
                # Skip only the comparison; returning here would end the test
                # after a single optimizer step.
                continue
            torch.testing.assert_close(ref_param, tst_param)


class TestFusedAdam(TestFusedOptimizer):

91
92
    def setup_method(self) -> None:
        """Select the Adam optimizer pair and the options shared by both."""
        super().setup_method()
        self.ref_optim = torch.optim.Adam
        self.fused_optim = te.optimizers.FusedAdam
        shared_options = {
            "lr": 5e-4,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "weight_decay": 0,
            "amsgrad": False,
        }
        self.options = shared_options

    def test_float(self):
        """Run the single-tensor comparison with FP32 parameters."""
        self.gen_single_type_test(torch.float)

    # NOTE(mkozuki): Current threshold values look too small for BFloat16.
    # TODO(mkozuki): Refactor `TestFusedOptimizer`
    def test_half(self):
        """Run the single-tensor test with FP16 parameters, without comparing results."""
        self.gen_single_type_test(torch.float16, skip_assert=True)

    def test_bfloat16(self):
        """Run the single-tensor test with BF16 parameters, without comparing results."""
        self.gen_single_type_test(torch.bfloat16, skip_assert=True)

    def test_multi_params(self):
        """Compare both optimizers on several FP32 tensors of different shapes."""
        shapes = ([4096, 1024], [4096], [4096, 2048], [32320, 1024], [1])
        tensors = [torch.rand(shape, dtype=torch.float, device="cuda") for shape in shapes]

        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, self.options)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()

            # Parameters must agree after every step, not only at the end.
            torch.testing.assert_close(ref_param, tst_param)

    def test_adam_option(self):
        """Compare both optimizers on a one-element tensor with non-default Adam options."""
        adam_option = {
            "lr": 0.01,
            "betas": (0.6, 0.9),
            "eps": 3e-06,
            "weight_decay": 0,
            "amsgrad": False,
        }
        single = torch.rand(1, dtype=torch.float, device="cuda")

        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([single], adam_option)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()

            torch.testing.assert_close(ref_param, tst_param)

    def test_frozen_model(self):
        """Check the fused optimizer still matches when it has an empty param group."""
        adam_option = {
            "lr": 0.01,
            "betas": (0.6, 0.9),
            "eps": 3e-06,
            "weight_decay": 0,
            "amsgrad": False,
        }
        single = torch.rand(1, dtype=torch.float, device="cuda")

        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([single], adam_option)

        # Add an empty param group which may occur for pipeline parallel p-tuning
        empty_group = {"params": []}
        tst_optim.add_param_group(empty_group)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()

            torch.testing.assert_close(ref_param, tst_param)

172
173
174
175
176
177
178
179
180
    def gen_precision_aware_test(
        self,
        use_fp8_params,
        param_dtype,
        use_master_weights,
        master_weight_dtype,
        grad_dtype,
        exp_avg_dtype,
        exp_avg_sq_dtype,
        store_param_remainders=False,
        model_rtol=None,
        model_atol=None,
        master_rtol=None,
        master_atol=None,
        skip_assert=False,
    ):
        """Compare precision-aware FusedAdam with an FP32 ``torch.optim.Adam`` reference.

        A ``MultiheadAttention`` module is built (inside ``fp8_model_init`` when
        ``use_fp8_params`` is set) and its trainable parameters are cloned to
        FP32 for the reference optimizer. Both optimizers take ``self.iters``
        steps; the fused optimizer is then rebuilt, restored from its own
        ``state_dict`` and stepped ``self.iters`` more times.

        Args:
            use_fp8_params: Build the model under ``fp8_model_init(enabled=True)``.
            param_dtype: ``params_dtype`` of the model.
            use_master_weights: Forwarded as ``master_weights`` to FusedAdam.
            master_weight_dtype: Forwarded as ``master_weight_dtype``.
            grad_dtype: dtype the reference gradient is cast to before it is
                stored in ``decoupled_grad``.
            exp_avg_dtype: Forwarded as ``exp_avg_dtype``.
            exp_avg_sq_dtype: Forwarded as ``exp_avg_sq_dtype``.
            store_param_remainders: Forwarded as ``store_param_remainders``;
                when set, the master-weight comparison is not performed.
            model_rtol, model_atol: Tolerances for the model-parameter check.
            master_rtol, master_atol: Tolerances for the master-weight check.
            skip_assert: Run every step but skip all comparisons.
        """
        model_ctx = fp8_model_init(enabled=True) if use_fp8_params else nullcontext()
        with model_ctx:
            model = MultiheadAttention(
                hidden_size=1024,
                num_attention_heads=16,
                layer_number=1,
                params_dtype=param_dtype,
                fuse_qkv_params=True,
            ).cuda()

        model_params = [p for p in model.parameters() if p.requires_grad]
        ref_params = [p.detach().clone().float() for p in model_params]

        options = {
            "lr": 1,
            "betas": (0.1, 0.25),
            "eps": 1e-08,
            "weight_decay": 0,
            "amsgrad": False,
        }

        def build_fused_adam():
            # The same configuration is needed again after the state_dict round trip.
            return te.optimizers.FusedAdam(
                model_params,
                master_weights=use_master_weights,
                master_weight_dtype=master_weight_dtype,
                exp_avg_dtype=exp_avg_dtype,
                exp_avg_sq_dtype=exp_avg_sq_dtype,
                use_decoupled_grad=True,
                store_param_remainders=store_param_remainders,
                **options,
            )

        def run_one_iteration(ref_optimizer, tst_optimizer):
            for p_ref, p in zip(ref_params, model_params):
                p_ref.grad = torch.rand_like(p_ref)
                p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
            ref_optimizer.step()
            tst_optimizer.step()
            if use_master_weights and not store_param_remainders:
                # Query the optimizer that was passed in, not a variable
                # captured from the enclosing scope.
                master_weights_to_fp32 = [
                    tst_optimizer.get_unscaled_state(p, "master_param") for p in model_params
                ]
                if not skip_assert:
                    torch.testing.assert_close(
                        ref_params,
                        master_weights_to_fp32,
                        rtol=master_rtol,
                        atol=master_atol,
                        equal_nan=True,
                    )
            ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params]
            if not skip_assert:
                torch.testing.assert_close(
                    ref_params_to_model_dtype,
                    model_params,
                    rtol=model_rtol,
                    atol=model_atol,
                    equal_nan=True,
                )

        ref_optim = torch.optim.Adam(ref_params, **options)
        tst_optim = build_fused_adam()

        for _ in range(self.iters):
            run_one_iteration(ref_optim, tst_optim)

        # Round-trip the fused optimizer state through state_dict/load_state_dict
        # and keep training with the restored optimizer.
        state_dict = tst_optim.state_dict()
        tst_optim = build_fused_adam()
        tst_optim.load_state_dict(state_dict)

        for _ in range(self.iters):
            run_one_iteration(ref_optim, tst_optim)

    def test_fp32_no_master(self):
        """Everything in FP32 and no master weights."""
        fp32 = torch.float32
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=fp32,
            use_master_weights=False,
            master_weight_dtype=fp32,
            grad_dtype=fp32,
            exp_avg_dtype=fp32,
            exp_avg_sq_dtype=fp32,
        )

    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_fp32_master(self):
        """BF16 model parameters with FP32 master weights, gradients and moments."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.float32,
            exp_avg_sq_dtype=torch.float32,
        )

301
302
303
304
305
306
307
308
309
310
311
312
313
    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_fp32_master_store_param_remainders(self):
        """BF16 parameters with FP32 master weights and ``store_param_remainders`` enabled."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.float32,
            exp_avg_sq_dtype=torch.float32,
            store_param_remainders=True,
        )

314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_fp16_master(self):
        """BF16 parameters with FP16 master weights; master check uses 2e-3 tolerances."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.half,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.float32,
            exp_avg_sq_dtype=torch.float32,
            master_rtol=2e-3,
            master_atol=2e-3,
        )

    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_bf16_grad(self):
        """BF16 parameters and BF16 gradients with FP32 master weights and moments."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.bfloat16,
            exp_avg_dtype=torch.float32,
            exp_avg_sq_dtype=torch.float32,
            master_rtol=2e-3,
            master_atol=2e-3,
        )

    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_fp16_exp_avg(self):
        """BF16 parameters with the ``exp_avg`` state kept in FP16."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.half,
            exp_avg_sq_dtype=torch.float32,
            master_rtol=2e-3,
            master_atol=2e-3,
        )

356
357
358
359
360
361
362
363
364
365
366
367
368
369
    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_bf16_exp_avg(self):
        """BF16 parameters with the ``exp_avg`` state kept in BF16."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.bfloat16,
            exp_avg_sq_dtype=torch.float32,
            master_rtol=2e-3,
            master_atol=2e-3,
        )

370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg(self):
        """BF16 parameters with ``exp_avg_dtype=torch.uint8`` (the FP8 case, per the test name)."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.uint8,
            exp_avg_sq_dtype=torch.float32,
            master_rtol=1e-2,
            master_atol=1e-2,
        )

    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_fp16_exp_avg_sq(self):
        """BF16 parameters with the ``exp_avg_sq`` state kept in FP16."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.float32,
            exp_avg_sq_dtype=torch.half,
            master_rtol=2e-3,
            master_atol=2e-3,
        )

399
400
401
402
403
404
405
406
407
408
409
410
411
412
    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_bf16_exp_avg_sq(self):
        """BF16 parameters with the ``exp_avg_sq`` state kept in BF16."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.float32,
            exp_avg_sq_dtype=torch.bfloat16,
            master_rtol=2e-3,
            master_atol=2e-3,
        )

413
414
415
416
417
418
419
420
421
422
423
424
425
426
    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg_sq(self):
        """BF16 parameters with ``exp_avg_sq_dtype=torch.uint8``; runs without comparisons."""
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.float32,
            exp_avg_sq_dtype=torch.uint8,
            skip_assert=True,
        )

427
    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
    def test_bf16_model_weight_cast(self):
        """Check master weights and BF16 model weights against an FP32 Adam reference.

        After every step the FusedAdam master parameters are compared with the
        reference using default tolerances, and the BF16 model parameters
        (cast to FP32) are compared with rtol/atol of 1e-3.
        """
        dtype = torch.bfloat16
        model = MultiheadAttention(
            hidden_size=1024,
            num_attention_heads=16,
            layer_number=1,
            params_dtype=dtype,
            fuse_qkv_params=True,
        ).cuda()
        model_params = [p for p in model.parameters() if p.requires_grad]
        ref_params = [p.detach().clone().float() for p in model_params]
        options = {
            "lr": 5e-4,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "weight_decay": 0,
            "amsgrad": False,
        }
        ref_optim = torch.optim.Adam(ref_params, **options)
        tst_optim = te.optimizers.FusedAdam(
            model_params, master_weights=True, use_decoupled_grad=True, **options
        )

        for _ in range(self.iters):
            for p_ref, p in zip(ref_params, model_params):
                p_ref.grad = torch.rand_like(p_ref)
                p.decoupled_grad = p_ref.grad.clone()
            ref_optim.step()
            tst_optim.step()
            master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
            torch.testing.assert_close(ref_params, master_params)
            model_params_to_fp32 = [p.float() for p in model_params]
            torch.testing.assert_close(
                ref_params, model_params_to_fp32, rtol=1e-3, atol=1e-3, equal_nan=True
            )

468
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_model_weight_cast(self):
        """Check master and model weights for a model built under ``fp8_model_init``.

        The model parameters (cast to FP32) are compared with the FP32 Adam
        reference using rtol/atol of 1e-2; master parameters use the default
        tolerances.
        """
        dtype = torch.bfloat16
        with fp8_model_init(enabled=True, recipe=DelayedScaling()):
            model = MultiheadAttention(
                hidden_size=1024,
                num_attention_heads=16,
                layer_number=1,
                params_dtype=dtype,
                fuse_qkv_params=True,
            ).cuda()

        model_params = [p for p in model.parameters() if p.requires_grad]
        ref_params = [p.detach().clone().float() for p in model_params]

        options = {
            "lr": 5e-4,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "weight_decay": 0,
            "amsgrad": False,
        }
        ref_optim = torch.optim.Adam(ref_params, **options)
        tst_optim = te.optimizers.FusedAdam(
            model_params, master_weights=True, use_decoupled_grad=True, **options
        )

        for _ in range(self.iters):
            for p_ref, p in zip(ref_params, model_params):
                p_ref.grad = torch.rand_like(p_ref)
                p.decoupled_grad = p_ref.grad.clone()

            ref_optim.step()
            tst_optim.step()

            master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
            torch.testing.assert_close(ref_params, master_params)

            model_params_to_fp32 = [p.float() for p in model_params]
            torch.testing.assert_close(
                ref_params, model_params_to_fp32, rtol=1e-2, atol=1e-2, equal_nan=True
            )

510

511
class TestFusedSGD(TestFusedOptimizer):
    """Runs the shared single-tensor comparison for FusedSGD against ``torch.optim.SGD``."""

    def setup_method(self) -> None:
        """Select the SGD optimizer pair and the options both are built with."""
        super().setup_method()
        self.ref_optim = torch.optim.SGD
        self.fused_optim = te.optimizers.FusedSGD
        self.options = {"lr": 0.25, "momentum": 0.125}

    def test_float(self):
        """Compare the optimizers on FP32 parameters."""
        self.gen_single_type_test(torch.float)

    def test_half(self):
        """Compare the optimizers on FP16 parameters."""
        self.gen_single_type_test(torch.float16)


class Model(torch.nn.Module):
    """Small conv/linear network: two conv-ReLU-pool stages, then three linear-ReLU stages."""

    def __init__(self):
        super().__init__()
        # Layers are created in forward order; attribute names define the
        # state_dict keys, so both must stay stable.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Apply the two convolution stages, flatten per sample, then the linear stages."""
        features = self.pool1(self.relu1(self.conv1(x)))
        features = self.pool2(self.relu2(self.conv2(features)))
        flat = features.reshape(features.shape[0], -1)
        hidden = self.relu3(self.fc1(flat))
        hidden = self.relu4(self.fc2(hidden))
        return self.relu5(self.fc3(hidden))


559
560
561
class AdamTest:

    def setup_method(self, *, seed: int = 0) -> None:
        """Seed the RNG and build two identical models plus the reference Adam optimizer."""
        torch.manual_seed(seed)

        # ``model`` is the reference; ``model_`` starts from a copy of its weights.
        self.model = Model().cuda()
        self.model_ = Model().cuda()
        reference_state = copy.deepcopy(self.model.state_dict())
        self.model_.load_state_dict(reference_state)

        self.lr = 0.00001
        trainable = [param for param in self.model.parameters() if param.requires_grad]
        self.optimizer = torch.optim.Adam(trainable, lr=self.lr)

572
    def test_grad_scaler(self):
        """Compare non-capturable FusedAdam with ``torch.optim.Adam`` under autocast + GradScaler.

        Each of the 100 iterations trains both models on the same random batch,
        compares the weights and weight gradients of the conv/linear layers,
        and then resets the test model to the reference weights.
        """
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)

        for _ in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
            x_ = x.clone()
            gt = torch.rand([32, 10]).cuda()
            gt_ = gt.clone()

            # Reference
            with gpu_autocast_ctx(enabled=True):
                y = self.model(x)
                loss = ((gt - y) ** 2).mean()

            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            # DUT: feed the cloned input so the two passes share no tensors.
            with gpu_autocast_ctx(enabled=True):
                y_ = self.model_(x_)
                loss_ = ((gt_ - y_) ** 2).mean()

            scaler_.scale(loss_).backward()
            scaler_.step(optimizer_)
            scaler_.update()

            for m, m_ in zip(self.model.modules(), self.model_.modules()):
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    torch.testing.assert_close(
                        m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
                    )
                    torch.testing.assert_close(
                        m.weight.grad,
                        m_.weight.grad,
                        atol=1e-3,
                        rtol=1e-3,
                        equal_nan=True,
                    )

            # Init for next iteration
            self.optimizer.zero_grad()
            optimizer_.zero_grad()

            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

623
    def test_grad_scaler_capturable(self):
        """Compare ``capturable=True`` FusedAdam with ``torch.optim.Adam`` under autocast + GradScaler.

        Each of the 100 iterations trains both models on the same random batch,
        compares the weights and weight gradients of the conv/linear layers,
        and then resets the test model to the reference weights.
        """
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)

        for _ in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
            x_ = x.clone()
            gt = torch.rand([32, 10]).cuda()
            gt_ = gt.clone()

            # Reference
            with gpu_autocast_ctx(enabled=True):
                y = self.model(x)
                loss = ((gt - y) ** 2).mean()

            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            # DUT: feed the cloned input so the two passes share no tensors.
            with gpu_autocast_ctx(enabled=True):
                y_ = self.model_(x_)
                loss_ = ((gt_ - y_) ** 2).mean()

            scaler_.scale(loss_).backward()
            scaler_.step(optimizer_)
            scaler_.update()

            for m, m_ in zip(self.model.modules(), self.model_.modules()):
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    torch.testing.assert_close(
                        m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
                    )
                    torch.testing.assert_close(
                        m.weight.grad,
                        m_.weight.grad,
                        atol=1e-3,
                        rtol=1e-3,
                        equal_nan=True,
                    )

            # Init for next iteration
            self.optimizer.zero_grad()
            optimizer_.zero_grad()

            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

674
    def test_grad_scaler_capturable_master(self):
        """Compare capturable FusedAdam with master weights against ``torch.optim.Adam``.

        The conv layers of the test model are cast to FP16 first; its weights
        and gradients are cast back to FP32 before each comparison.
        """
        # Cast conv layers to FP16
        for m in self.model_.modules():
            if m.__class__ in [torch.nn.Conv2d]:
                m.half()
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        # NOTE(review): a list of FP32 tensors is passed as ``master_weights``;
        # confirm FusedAdam expects this here rather than a bool.
        master_weights = [p.float() for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = te.optimizers.FusedAdam(
            params_, lr=self.lr, capturable=True, master_weights=master_weights
        )
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)

        for _ in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
            x_ = x.clone()
            gt = torch.rand([32, 10]).cuda()
            gt_ = gt.clone()

            # Reference
            with gpu_autocast_ctx(enabled=True):
                y = self.model(x)
                loss = ((gt - y) ** 2).mean()

            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            # DUT: feed the cloned input so the two passes share no tensors.
            with gpu_autocast_ctx(enabled=True):
                y_ = self.model_(x_)
                loss_ = ((gt_ - y_) ** 2).mean()

            scaler_.scale(loss_).backward()
            scaler_.step(optimizer_)
            scaler_.update()

            for m, m_ in zip(self.model.modules(), self.model_.modules()):
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    torch.testing.assert_close(
                        m.weight,
                        m_.weight.float(),
                        atol=1e-3,
                        rtol=1e-3,
                        equal_nan=True,
                    )
                    torch.testing.assert_close(
                        m.weight.grad,
                        m_.weight.grad.float(),
                        atol=1e-3,
                        rtol=1e-3,
                        equal_nan=True,
                    )

            # Init for next iteration
            self.optimizer.zero_grad()
            optimizer_.zero_grad()

            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

736
    def test_native(self):
        """Compare non-capturable FusedAdam with ``torch.optim.Adam`` without autocast or scaling.

        Each of the 100 iterations trains both models on the same random batch,
        compares the weights and weight gradients of the conv/linear layers,
        and then resets the test model to the reference weights.
        """
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)

        for _ in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
            x_ = x.clone()
            gt = torch.rand([32, 10]).cuda()
            gt_ = gt.clone()

            # Reference
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()

            loss.backward()
            self.optimizer.step()

            # DUT: feed the cloned input so the two passes share no tensors.
            y_ = self.model_(x_)
            loss_ = ((gt_ - y_) ** 2).mean()

            loss_.backward()
            optimizer_.step()

            for m, m_ in zip(self.model.modules(), self.model_.modules()):
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    torch.testing.assert_close(
                        m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
                    )
                    torch.testing.assert_close(
                        m.weight.grad,
                        m_.weight.grad,
                        atol=1e-3,
                        rtol=1e-3,
                        equal_nan=True,
                    )

            # Init for next iteration
            self.optimizer.zero_grad()
            optimizer_.zero_grad()

            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

    @largeTensorTest("60GB", "cuda")
    def test_large_tensor(self):
        """Step FusedAdam and ``torch.optim.Adam`` once on a 2359332864-element FP16 tensor.

        Both tensors start as zeros and share one random gradient, so after a
        single step of each optimizer they are expected to match.
        """
        t = torch.zeros(2359332864, dtype=torch.half, device="cuda")
        t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda")
        grad = torch.randn_like(t)
        t.grad = grad
        t2.grad = grad
        params = [t]
        params2 = [t2]
        optimizer = te.optimizers.FusedAdam(params, lr=self.lr)
        optimizer.step()
        optimizer2 = torch.optim.Adam(params2, lr=self.lr)
        # The reference must take its step too; otherwise the check compares
        # the updated tensor against untouched zeros.
        optimizer2.step()
        torch.cuda.synchronize()
        torch.testing.assert_close(t, t2)