import os
import unittest

import backend as F
import pytest
import torch as th
import torch.multiprocessing as mp

from dgl.nn import NodeEmbedding
from dgl.optim import SparseAdagrad, SparseAdam


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam(emb_dim):
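    # Run one optimizer step of dgl.optim.SparseAdam and torch.optim.SparseAdam
    # on identically initialized embeddings and check the weights still match.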
    num_embs = 10
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test")
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # first step
    idx = th.randint(0, num_embs, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    labels = th.zeros((4,)).long()
    print("dgl_value = {}".format(dgl_value))
    print("labels = {}".format(labels))

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)

    # Cannot test a second step:
    # PyTorch's SparseAdam maintains a single global step count,
    # while DGL's SparseAdam keeps a per-embedding step count.


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("use_uva", [False, True, None])
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam_uva(use_uva, emb_dim):
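    # Same single-step comparison as test_sparse_adam, but exercising
    # SparseAdam's `use_uva` option (False, True, or None).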
    if F.ctx().type == "cpu" and use_uva is True:
        # we want to only test values of False and None when not using GPU
        pytest.skip("UVA cannot be used without GPUs.")

    num_embs = 10
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test_uva{}".format(use_uva))
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01, use_uva=use_uva)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # first step
    idx = th.randint(0, num_embs, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    labels = th.zeros((4,)).long()

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)

    # Cannot test a second step:
    # PyTorch's SparseAdam maintains a single global step count,
    # while DGL's SparseAdam keeps a per-embedding step count.


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("dtype", [th.float32, th.float16])
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam_dtype(dtype, emb_dim):
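    # Same single-step comparison as test_sparse_adam, but constructing
    # SparseAdam with an explicit `dtype` (float32 or float16).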
    num_embs = 10
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test_dtype{}".format(dtype))
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01, dtype=dtype)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # first step
    idx = th.randint(0, num_embs, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    labels = th.zeros((4,)).long()

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)

    # Cannot test a second step:
    # PyTorch's SparseAdam maintains a single global step count,
    # while DGL's SparseAdam keeps a per-embedding step count.


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_sparse_adam_zero_step():
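    # A second embedding is registered with each optimizer but never receives a
    # gradient; the embedding that does get a gradient must still match PyTorch
    # after one step.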
    num_embs = 10
    emb_dim = 4
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test")
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    dgl_emb_zero = NodeEmbedding(num_embs, emb_dim, "test2")
    torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)
    th.nn.init.uniform_(dgl_emb_zero.weight, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    torch_adam = th.optim.SparseAdam(
        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()),
        lr=0.01,
    )

    # first step
    idx = th.randint(0, num_embs, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    labels = th.ones((4,)).long()

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)


def initializer(emb):
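    # Deterministic init_func for NodeEmbedding: re-seeding before the uniform
    # fill makes every process start from identical weights.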
    th.manual_seed(0)
    emb.uniform_(-1.0, 1.0)
    return emb


def start_sparse_adam_worker(
    rank,
    device,
    world_size,
    weight,
    tensor_dev="cpu",
    has_zero_grad=False,
    backend="gloo",
    num_embs=128,
    emb_dim=10,
    zero_comm=True,
):
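    # Distributed DGL worker: join the process group, run one SparseAdam step
    # on a shared NodeEmbedding, and let rank 0 copy the updated weights into
    # `weight` for comparison in the parent process. With zero_comm=True each
    # rank only samples indices from its own slice of the embedding rows.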
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )

    if device.type == "cuda":
        th.cuda.set_device(device)

    th.distributed.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=world_size,
        rank=rank,
    )

    init_weight = th.empty((num_embs, emb_dim))
    th.manual_seed(0)
    th.nn.init.uniform_(init_weight, -1.0, 1.0)
    dgl_emb = NodeEmbedding(
        num_embs, emb_dim, "test", init_func=initializer, device=tensor_dev
    )
    dgl_emb.all_set_embedding(init_weight)

    if has_zero_grad:
        dgl_emb_zero = NodeEmbedding(
            num_embs, emb_dim, "zero", init_func=initializer, device=tensor_dev
        )
        dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    else:
        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)

    th.manual_seed(rank)
    if zero_comm:
        start = (num_embs // world_size) * rank
        end = (num_embs // world_size) * (rank + 1)
        idx = th.randint(start, end, size=(4,)).to(tensor_dev)
    else:
        idx = th.randint(0, num_embs, size=(4,)).to(tensor_dev)
    dgl_value = dgl_emb(idx, device)
    labels = th.ones((4,)).long().to(device)
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_adam.zero_grad()
    dgl_loss.backward()
    dgl_adam.step()
    th.distributed.barrier()
    dgl_weight = dgl_emb.all_get_embedding().detach()
    after_step = dgl_emb(idx, device).cpu()

    if rank == 0:
        dgl_value = dgl_value.detach().cpu()
        assert F.allclose(dgl_value, after_step) is False
        weight[:] = dgl_weight[:]
    th.distributed.barrier()


def start_torch_adam_worker(
    rank,
    world_size,
    weight,
    has_zero_grad=False,
    num_embs=128,
    emb_dim=10,
    zero_comm=True,
):
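    # Reference worker: the same single optimizer step, but with a plain
    # torch.nn.Embedding wrapped in DistributedDataParallel and optimized with
    # torch.optim.SparseAdam.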
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )
    backend = "gloo"

    th.distributed.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=world_size,
        rank=rank,
    )

    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, -1.0, 1.0)
    torch_emb = th.nn.parallel.DistributedDataParallel(torch_emb)
    if has_zero_grad:
        torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)
        th.manual_seed(0)
        th.nn.init.uniform_(torch_emb_zero.weight, -1.0, 1.0)
        torch_emb_zero = th.nn.parallel.DistributedDataParallel(torch_emb_zero)
        torch_adam = th.optim.SparseAdam(
            list(torch_emb.module.parameters())
            + list(torch_emb_zero.module.parameters()),
            lr=0.01,
        )
    else:
        torch_adam = th.optim.SparseAdam(
            list(torch_emb.module.parameters()), lr=0.01
        )

    th.manual_seed(rank)
    if zero_comm:
        start = (num_embs // world_size) * rank
        end = (num_embs // world_size) * (rank + 1)
        idx = th.randint(start, end, size=(4,))
    else:
        idx = th.randint(0, num_embs, size=(4,))
    labels = th.ones((4,)).long()
    torch_value = torch_emb(idx)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    torch_adam.zero_grad()
    torch_loss.backward()
    torch_adam.step()
    th.distributed.barrier()

    if rank == 0:
        weight[:] = torch_emb.module.weight.cpu()[:]
    th.distributed.barrier()


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type != "cpu", reason="cpu only test")
@pytest.mark.parametrize("num_workers", [2, 4])
def test_multiprocess_cpu_sparse_adam(num_workers):
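    # Run the DGL workers and the reference torch workers as two separate
    # rounds of processes, then compare the resulting embedding weights.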
    backend = "gloo"
    worker_list = []
    num_embs = 128
    emb_dim = 10
    dgl_weight = th.empty((num_embs, emb_dim))
    ctx = mp.get_context("spawn")
    for i in range(num_workers):
        device = F.ctx()
        p = ctx.Process(
            target=start_sparse_adam_worker,
            args=(
                i,
                device,
                num_workers,
                dgl_weight,
                th.device("cpu"),
                True,
                backend,
            ),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    worker_list = []
    torch_weight = th.empty((num_embs, emb_dim))
    for i in range(num_workers):
        p = ctx.Process(
            target=start_torch_adam_worker,
            args=(i, num_workers, torch_weight, False),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    assert F.allclose(dgl_weight, torch_weight)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
@pytest.mark.parametrize("num_workers", [2, 4, 8])
@pytest.mark.parametrize("backend", ["nccl", "gloo"])
@pytest.mark.parametrize("zero_comm", [True, False])
def test_multiprocess_sparse_adam(num_workers, backend, zero_comm):
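    # GPU variant: each worker drives its own GPU while the NodeEmbedding
    # storage stays on the CPU; `zero_comm` controls whether workers touch
    # disjoint row ranges.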
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    worker_list = []
    num_embs = 128
    emb_dim = 10
    dgl_weight = th.empty((num_embs, emb_dim))
    ctx = mp.get_context("spawn")
    for i in range(num_workers):
        device = F.ctx()
        if device.type == "cuda":
            # make sure each process has a unique GPU
            device = th.device(i)
        p = ctx.Process(
            target=start_sparse_adam_worker,
            args=(
                i,
                device,
                num_workers,
                dgl_weight,
                th.device("cpu"),
                True,
                backend,
                num_embs,
                emb_dim,
                zero_comm,
            ),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    worker_list = []
    torch_weight = th.empty((num_embs, emb_dim))
    for i in range(num_workers):
        p = ctx.Process(
            target=start_torch_adam_worker,
            args=(
                i,
                num_workers,
                torch_weight,
                False,
                num_embs,
                emb_dim,
                zero_comm,
            ),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    assert F.allclose(dgl_weight, torch_weight)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    F.ctx().type == "cpu", reason="cuda tensor is not supported for cpu"
)
@pytest.mark.parametrize("num_workers", [2, 4, 8])
def test_multiprocess_sparse_adam_cuda_tensor(num_workers):
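    # Like test_multiprocess_sparse_adam, but the NodeEmbedding storage lives
    # on each worker's GPU (tensor_dev=device) and NCCL is the backend.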
    if F.ctx().type == "cpu":
        pytest.skip("Do not test CPU")
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    backend = "nccl"
    worker_list = []
    num_embs = 128
    emb_dim = 10
    dgl_weight = th.empty((num_embs, emb_dim))
    ctx = mp.get_context("spawn")
    for i in range(num_workers):
        device = th.device(i)
        p = ctx.Process(
            target=start_sparse_adam_worker,
            args=(i, device, num_workers, dgl_weight, device, False, backend),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    worker_list = []
    torch_weight = th.empty((num_embs, emb_dim))
    for i in range(num_workers):
        p = ctx.Process(
            target=start_torch_adam_worker,
            args=(i, num_workers, torch_weight, False),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    assert F.allclose(dgl_weight, torch_weight)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type != "cpu", reason="cpu only test")
@pytest.mark.parametrize("num_workers", [2, 4])
def test_multiprocess_sparse_adam_cpu_zero_step(num_workers):
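    # CPU zero-step variant: the DGL workers also register an embedding that
    # never receives a gradient (has_zero_grad=True).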
    backend = "gloo"

    worker_list = []
    num_embs = 128
    emb_dim = 10
    dgl_weight = th.empty((num_embs, emb_dim))
    ctx = mp.get_context("spawn")
    for i in range(num_workers):
        device = F.ctx()
        p = ctx.Process(
            target=start_sparse_adam_worker,
            args=(
                i,
                device,
                num_workers,
                dgl_weight,
                th.device("cpu"),
                True,
                backend,
            ),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    worker_list = []
    torch_weight = th.empty((num_embs, emb_dim))
    for i in range(num_workers):
        p = ctx.Process(
            target=start_torch_adam_worker,
            args=(i, num_workers, torch_weight, False),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    assert F.allclose(dgl_weight, torch_weight)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
@pytest.mark.parametrize("num_workers", [2, 4, 8])
@pytest.mark.parametrize("backend", ["nccl", "gloo"])
def test_multiprocess_sparse_adam_zero_step(num_workers, backend):
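    # GPU zero-step variant: each worker gets its own GPU and the DGL workers
    # register an embedding that never receives a gradient (has_zero_grad=True).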
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    worker_list = []
    num_embs = 128
    emb_dim = 10
    dgl_weight = th.empty((num_embs, emb_dim))
    ctx = mp.get_context("spawn")
    for i in range(num_workers):
        device = F.ctx()
        if device.type == "cuda":
            # make sure each process has a unique GPU
            device = th.device(i)
        p = ctx.Process(
            target=start_sparse_adam_worker,
            args=(
                i,
                device,
                num_workers,
                dgl_weight,
                th.device("cpu"),
                True,
                backend,
            ),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    worker_list = []
    torch_weight = th.empty((num_embs, emb_dim))
    for i in range(num_workers):
        p = ctx.Process(
            target=start_torch_adam_worker,
            args=(i, num_workers, torch_weight, False),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    assert F.allclose(dgl_weight, torch_weight)


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    F.ctx().type == "cpu", reason="cuda tensor is not supported for cpu"
)
@pytest.mark.parametrize("num_workers", [2, 4, 8])
def test_multiprocess_sparse_adam_zero_step_cuda_tensor(num_workers):
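    # Zero-step variant with NodeEmbedding storage on each worker's GPU and
    # NCCL as the backend (has_zero_grad=True).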
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    backend = "nccl"
    worker_list = []
    num_embs = 128
    emb_dim = 10
    dgl_weight = th.empty((num_embs, emb_dim))
    ctx = mp.get_context("spawn")
    for i in range(num_workers):
        device = th.device(i)
        p = ctx.Process(
            target=start_sparse_adam_worker,
            args=(i, device, num_workers, dgl_weight, device, True, backend),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    worker_list = []
    torch_weight = th.empty((num_embs, emb_dim))
    for i in range(num_workers):
        p = ctx.Process(
            target=start_torch_adam_worker,
            args=(i, num_workers, torch_weight, False),
        )
        p.start()
        worker_list.append(p)
    for p in worker_list:
        p.join()

    assert F.allclose(dgl_weight, torch_weight)


def start_sparse_adam_state_dict_worker(
    rank,
    world_size,
    init_weight,
    backend,
    num_embs,
    emb_dim,
):
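    # Worker for the state_dict round-trip check: take one SparseAdam step,
    # snapshot the per-embedding optimizer state, zero it out, reload it from
    # the saved state_dict, and verify the state matches the snapshot.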
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )

    device = th.device(f"cuda:{rank}")
    th.cuda.set_device(device)
    tensor_dev = device if backend == "nccl" else th.device("cpu")

    th.distributed.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=world_size,
        rank=rank,
    )

    th.manual_seed(0)
    dgl_emb = NodeEmbedding(
        num_embs, emb_dim, "test", init_func=initializer, device=tensor_dev
    )
    dgl_emb.all_set_embedding(init_weight)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)

    start = (num_embs // world_size) * rank
    end = (num_embs // world_size) * (rank + 1)
    th.manual_seed(rank)
    idx = th.randint(start, end, size=(4,)).to(tensor_dev)
    dgl_value = dgl_emb(idx, device)
    labels = th.ones((4,)).long().to(device)
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_adam.zero_grad()
    dgl_loss.backward()
    dgl_adam.step()
    th.distributed.barrier()

    worker_state_dict = [t.detach().clone() for t in dgl_emb.optm_state]
    state_dict = dgl_adam.state_dict()
    for t in dgl_emb.optm_state:
        t.zero_()
    dgl_adam.load_state_dict(state_dict)

    for i, j in zip(worker_state_dict, dgl_emb.optm_state):
        assert F.allclose(i, j)

    th.distributed.barrier()


@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
@pytest.mark.parametrize("num_workers", [1, 2, 4, 8])
@pytest.mark.parametrize("backend", ["nccl", "gloo"])
def test_multiprocess_sparse_adam_state_dict(num_workers, backend):
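    # Spawn workers that exercise SparseAdam's state_dict()/load_state_dict()
    # round trip under torch.multiprocessing.spawn with NCCL or Gloo.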
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    num_embs = 128
    emb_dim = 10
    init_weight = th.rand((num_embs, emb_dim))
    mp.spawn(
        start_sparse_adam_state_dict_worker,
        (
            num_workers,
            init_weight,
            backend,
            num_embs,
            emb_dim,
        ),
        nprocs=num_workers,
    )


if __name__ == "__main__":
    test_sparse_adam(1)
    test_sparse_adam(4)
    test_sparse_adam(101)
    test_sparse_adam(1024)
    test_sparse_adam_zero_step()

    test_multiprocess_cpu_sparse_adam(2)
    test_multiprocess_cpu_sparse_adam(4)
    test_multiprocess_cpu_sparse_adam(8)
    test_multiprocess_sparse_adam_cpu_zero_step(2)

    test_multiprocess_sparse_adam(2, backend="gloo", zero_comm=True)
    test_multiprocess_sparse_adam(4, backend="gloo", zero_comm=True)
    test_multiprocess_sparse_adam(8, backend="gloo", zero_comm=True)
    test_multiprocess_sparse_adam(2, backend="nccl", zero_comm=True)
    test_multiprocess_sparse_adam(4, backend="nccl", zero_comm=True)
    test_multiprocess_sparse_adam(8, backend="nccl", zero_comm=True)

    test_multiprocess_sparse_adam_zero_step(2, backend="gloo")
    test_multiprocess_sparse_adam_zero_step(4, backend="nccl")

    test_multiprocess_sparse_adam_cuda_tensor(2)
    test_multiprocess_sparse_adam_zero_step_cuda_tensor(4)

    test_multiprocess_sparse_adam_state_dict(2, "nccl")
    test_multiprocess_sparse_adam_state_dict(2, "gloo")