test_numerical.py 17.3 KB
Newer Older
1
import sys
2
from collections import OrderedDict
3
from typing import List, Type, Union
Sengxian's avatar
Sengxian committed
4

5
import pytest
Rick Ho's avatar
Rick Ho committed
6
import torch
7
import torch.nn as nn
8
import numpy as np
9

10
from copy import deepcopy
11
12
13
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
14
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
Rick Ho's avatar
Rick Ho committed
15
from fmoe.megatron.layers import _megatron_init_method
16
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
Rick Ho's avatar
Rick Ho committed
17
18


19
def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group, data_type='torch.FloatTensor'
):
    """Run one forward/backward pass through both the fused MoE and the
    brute-force reference, and return (moe_out, raw_out, inp.grad, inp_raw.grad).

    Both modules see the same random input; the reference is driven by the
    gate decisions of the fused module so routing is identical.
    """
    for module in (moe, moe_raw):
        module.zero_grad()

    inp = torch.rand(batch_size, d_model).type(data_type).cuda()

    if mp_group is not None:
        # Every rank of a model-parallel group must see the same input and
        # the same gate parameters; broadcast from the group's first rank.
        sender = rank // mp_group.size() * mp_group.size()
        for tensor in (
            inp,
            moe.gate.gate.weight.data,
            moe.gate.gate.bias.data,
        ):
            torch.distributed.broadcast(tensor, sender, group=mp_group)

    inp_raw = inp.clone()
    inp.requires_grad = True
    inp_raw.requires_grad = True

    # Evaluate the gate once on the raw copy and feed the same routing to
    # the reference implementation.
    gate_idx, gate_score = moe.gate(inp_raw)
    moe_out = moe(inp)
    raw_out = moe_raw(inp_raw, gate_idx, gate_score)

    raw_out.mean().backward()
    moe_out.mean().backward()

    return moe_out, raw_out, inp.grad, inp_raw.grad
49
50


51
def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
52
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
Rick Ho's avatar
Rick Ho committed
53
54
55
56
        if mo is None and ro is None:
            continue
        if mo is None or ro is None:
            assert False
57
        err = (mo - ro).abs().max()
Rick Ho's avatar
Rick Ho committed
58
        if err.dtype == torch.bfloat16 or err.dtype == torch.float16:
Rick Ho's avatar
Rick Ho committed
59
            precision *= 100
60
        print("Rank {} {} abs err {}".format(rank, name, err))
61
        if err > precision:
Sengxian's avatar
Sengxian committed
62
            sys.stderr.write(f"=========== {name} moe out ==============\n")
63
            sys.stderr.write("{}\n".format(mo))
Sengxian's avatar
Sengxian committed
64
            sys.stderr.write(f"=========== {name} raw out ==============\n")
65
            sys.stderr.write("{}\n".format(ro))
66
67
            sys.stderr.write(f"=========== {name} diff ==============\n")
            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
68
69
70
            assert False


71
class MyMoE(FMoE):
    """FMoE with a two-layer feed-forward expert, initialized from a fixed
    RNG seed so that every rank (and every run) starts identically."""

    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            slice_group=mp_group,
            top_k=top_k,
        )
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        # Fixed seed => deterministic, rank-independent expert weights.
        rng = np.random.default_rng(1234)
        for linear in (self.experts.htoh4, self.experts.h4toh):
            _megatron_init_method(linear, rng, 1.0)
88

89

90
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type,
    activation=torch.nn.functional.gelu,
):
    """Check FMoE's fused linear experts against BruteForceMoELinear.

    Copies (or, multi-process, all-gathers) the expert weights into the
    brute-force reference, runs a forward/backward pass through both, and
    asserts that outputs and all gradients agree.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(data_type, str):
        # Resolve 'torch.float32' etc. to the dtype without eval().
        data_type = getattr(torch, data_type.split('.')[-1])

    moe = MyMoE(
        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ).type(data_type).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).type(data_type).cuda()

    if world_size == 1:
        # Single process: copy expert weights directly.
        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
    else:
        # Multi-process: the reference holds ALL experts, so gather every
        # rank's expert weights and concatenate along the expert dim.
        weight_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
        ]
        bias_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
        ]
        bias_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
    )

    moe_out_list = (
        moe_out,
        moe_grad_in,
        moe.experts.htoh4.weight.grad,
        moe.experts.h4toh.weight.grad,
        moe.experts.htoh4.bias.grad,
        moe.experts.h4toh.bias.grad,
    )
    raw_out_list = (
        raw_out,
        raw_grad_in,
        moe_raw.weight_htoh4.grad,
        moe_raw.weight_h4toh.grad,
        moe_raw.bias_htoh4.grad,
        moe_raw.bias_h4toh.grad,
    )

    if world_size > 1:
        # Reference gradients are summed across ranks; slice out this
        # rank's experts and undo the model-parallel duplication factor.
        _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_w_grad)
        torch.distributed.all_reduce(h4toh_w_grad)
        torch.distributed.all_reduce(htoh4_b_grad)
        torch.distributed.all_reduce(h4toh_b_grad)
        mp_size = mp_group.size() if mp_group else 1
        htoh4_w_grad = (
            htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_w_grad = (
            h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        htoh4_b_grad = (
            htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_b_grad = (
            h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

    names = [
        "output",
        "input grad",
        "htoh4 weight grad",
        "h4toh weight grad",
        "htoh4 bias grad",
        "h4toh bias grad",
    ]

    # BUG FIX: by this point `data_type` is a torch dtype, so the old
    # comparison against the string 'torch.HalfTensor' was never true and
    # fp16 runs silently used the strict fp32 tolerance.
    precision = 5e-1 if data_type == torch.float16 else 1e-3

    _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=precision)
217

Sengxian's avatar
Sengxian committed
218

219
220
221
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type
):
    """Compare FMoE with a generic expert module against BruteForceMoE.

    Synchronizes expert parameters between the two models, runs a
    forward/backward pass through both, and asserts outputs, per-expert
    gradient sums, and input gradients agree.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(expert, str):
        # `expert` may arrive as a class name when launched via spawn
        # helpers; resolve it to the class defined in this module.
        expert = globals()[expert]
        assert expert is not None
    if isinstance(data_type, str):
        # Resolve 'torch.float32' etc. to the dtype without eval().
        data_type = getattr(torch, data_type.split('.')[-1])

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda().type(data_type)

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda().type(data_type)

    if world_size == 1:
        # Single process: copy parameters expert-by-expert.
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        # Multi-process: all-gather every rank's expert parameters into the
        # brute-force model, which holds all experts of all ranks.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            assert para.device.type == 'cuda'
            # FIX: comprehension variable renamed so it no longer shadows
            # the `expert` argument (a class, not a module instance).
            para_tensor = torch.cat(
                [list(exp.parameters())[idx].unsqueeze(0) for exp in moe.experts]
            )
            assert para_tensor.device.type == 'cuda'
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expert_id in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expert_id].parameters())[
                    idx
                ].data = para_tensor_gathered[expert_id]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
    )

    def get_experts_grad(experts: List[nn.Module]):
        # One scalar per expert: the sum over all its parameter gradients
        # (a missing grad counts as zero).
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        # Reference gradients are summed across ranks; keep only this
        # rank's experts and undo the model-parallel duplication factor.
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numerical(names, moe_out_list, raw_out_list, rank)
Sengxian's avatar
Sengxian committed
325
326


327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
class MyModule(nn.Module):
    """Small three-layer MLP used to check gradient synchronization in
    DistributedGroupedDataParallel.

    `set_comm` tags each linear layer's parameters with the `dp_comm`
    group ("mp" / "dp" / "world") that the DDP wrapper reduces over.
    """

    def __init__(self, dim=8):
        super().__init__()
        layers = OrderedDict()
        layers["linear1"] = nn.Linear(dim, dim)
        layers["relu1"] = nn.ReLU()
        layers["linear2"] = nn.Linear(dim, dim)
        layers["relu2"] = nn.ReLU()
        layers["linear3"] = nn.Linear(dim, dim)
        self.model = nn.Sequential(layers)

    def set_comm(self):
        # One linear layer per communication group.
        comm_tags = {"linear1": "mp", "linear2": "dp", "linear3": "world"}
        for layer_name, comm in comm_tags.items():
            for param in self.model._modules[layer_name].parameters():
                param.dp_comm = comm

    def forward(self, inp):
        return self.model(inp)


def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
    """Check DistributedGroupedDataParallel's per-group gradient averaging
    against manual all_reduce + division on a plain copy of the model."""
    batch_size, dim = 4, 8

    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    model = MyModule().cuda()
    model_ddp = LocalDDP(deepcopy(model),
            mp_group=mp_group, dp_group=dp_group, world_group=world_group)
    # Reference model starts from the exact weights the DDP wrapper holds.
    model = deepcopy(model_ddp.module)

    model.set_comm()
    model_ddp.module.set_comm()

    inp = torch.randn(batch_size, dim).cuda()

    raw_out = model(inp).mean()
    ddp_out = model_ddp(inp).mean()

    raw_out.backward()
    ddp_out.backward()

    # Manually average each reference layer's gradient over the group it
    # was tagged with in set_comm (linear1->mp, linear2->dp, linear3->world).
    layer_groups = (
        ("linear1", mp_group),
        ("linear2", dp_group),
        ("linear3", world_group),
    )
    for layer_name, group in layer_groups:
        weight = model.model._modules[layer_name].weight
        torch.distributed.all_reduce(weight.grad.data, group=group)
        weight.grad /= group.size()
    # The wrapper should produce the same averages automatically.
    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)

    raw_out_list = [
        model.model._modules[layer_name].weight.grad
        for layer_name, _ in layer_groups
    ]
    ddp_out_list = [
        model_ddp.module.model._modules[layer_name].weight.grad
        for layer_name, _ in layer_groups
    ]

    names = ["mp grad", "dp grad", "wp grad"]

    _assert_numerical(names, ddp_out_list, raw_out_list, rank)
403
404


Colin's avatar
Colin committed
405
406
407
408
409
410
411
412
413
414
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [None])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [ [NaiveExpert for _ in range(4)], [LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert] ])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", [torch.float32])
def test_fmoe_experts(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type
):
    """Same check as test_fmoe, but `expert` is a *list* of expert classes
    (heterogeneous experts, num_expert=None) rather than a single class.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(expert, str):
        expert = globals()[expert]
    if isinstance(data_type, str):
        # Resolve 'torch.float32' etc. to the dtype without eval().
        data_type = getattr(torch, data_type.split('.')[-1])

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda().type(data_type)

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda().type(data_type)

    if world_size == 1:
        # Single process: copy parameters expert-by-expert.
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        # Multi-process: all-gather every rank's expert parameters into the
        # brute-force model, which holds all experts of all ranks.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            # BUG FIX: the old code iterated `expert.parameters()`, but
            # here `expert` is a list of expert *classes*, which would
            # raise AttributeError; check the gathered parameter instead
            # (matching test_fmoe).
            assert para.device.type == 'cuda'
            para_tensor = torch.cat(
                [list(exp.parameters())[idx].unsqueeze(0) for exp in moe.experts]
            )
            assert para_tensor.device.type == 'cuda'
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expert_id in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expert_id].parameters())[
                    idx
                ].data = para_tensor_gathered[expert_id]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
    )

    def get_experts_grad(experts: List[nn.Module]):
        # One scalar per expert: the sum over all its parameter gradients
        # (a missing grad counts as zero).
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numerical(names, moe_out_list, raw_out_list, rank)


513
if __name__ == "__main__":
    # Ad-hoc entry point for running a single configuration directly,
    # without pytest (requires a CUDA device).
    # NOTE(review): `expert` is passed as a *list* of expert classes, which
    # matches test_fmoe_experts' parametrization rather than test_fmoe's
    # single-class one, and num_expert=2 disagrees with the 4-element list
    # — confirm whether test_fmoe_experts was the intended target.
    test_fmoe(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        expert=[NaiveExpert for _ in range(4)],
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
        data_type=torch.bfloat16
    )