import random
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer

from .pp_plugin_base import PipelinePluginBase

# Axis indices into the process-group mesh, which HybridParallelPlugin builds as
# ProcessGroupMesh(dp_size, pp_size, tp_size): data, pipeline and tensor parallel.
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2


33
34
35
36
37
38
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    """Cast ``x`` to ``dtype`` when it is a floating-point tensor.

    Anything else (integer/bool tensors, non-tensor objects) is returned as-is,
    which makes this safe to map over arbitrary argument trees.
    """
    is_float_tensor = isinstance(x, torch.Tensor) and torch.is_floating_point(x)
    return x.to(dtype) if is_float_tensor else x


39
class HybridParallelModule(ModelWrapper):
    """Model wrapper used by the hybrid parallel plugin.

    On construction the wrapped module is sharded with ``ShardFormer``, cast to the
    requested precision, moved to CUDA and, when ``use_ddp`` is set, converted to
    SyncBatchNorm and wrapped in PyTorch DDP over ``dp_group``.

    Attributes set here:
        stage_manager: ``shard_config.pipeline_stage_manager`` (``None`` is passed by the
            plugin when pipeline parallelism is off).
        dp_group: the data parallel process group.
        shared_params: second value returned by ``ShardFormer.optimize``; iterated as a
            sequence of mappings from pipeline stage to parameter.
        shared_param_process_groups: one process group per non-empty entry of
            ``shared_params``, in the same relative order.
        mixed_precision: ``torch.float16`` / ``torch.bfloat16`` or ``None`` for other precisions.
        convert_fn: callable casting floating-point tensors to ``mixed_precision``, or ``None``.
    """

    def __init__(
        self,
        module: Module,
        precision: str,
        shard_config: ShardConfig,
        dp_group: ProcessGroup,
        use_ddp: bool,
        ddp_config: dict,
        custom_policy: Policy,
    ) -> None:
        """Shard ``module`` and prepare it for hybrid parallel training.

        Args:
            module: the model to wrap.
            precision: ``"fp16"`` or ``"bf16"`` enable mixed precision; any other value
                leaves the parameter dtype untouched.
            shard_config: configuration handed to ``ShardFormer``.
            dp_group: data parallel process group (used for DDP and ``sync_grads``).
            use_ddp: whether to wrap the sharded module in PyTorch DDP.
            ddp_config: extra keyword arguments forwarded to ``DDP``.
            custom_policy: optional sharding policy forwarded to ``ShardFormer.optimize``.
        """
        # NOTE: plain attributes are assigned before ``super().__init__`` on purpose;
        # the parent constructor is only called once the final module is built.
        self.stage_manager = shard_config.pipeline_stage_manager
        self.dp_group = dp_group

        shardformer = ShardFormer(shard_config)
        if custom_policy is not None:
            # NOTE(review): this check is always true for any Python object; it is kept
            # as-is so that duck-typed policies continue to be accepted.
            assert isinstance(custom_policy, object)
        module, self.shared_params = shardformer.optimize(module, policy=custom_policy)

        # setting process groups for shared parameters
        # (only non-empty entries get a group, see ``sync_shared_params``)
        self.shared_param_process_groups = []
        for shared_param in self.shared_params:
            if len(shared_param) > 0:
                self.shared_param_process_groups.append(
                    self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
                )

        # setting mixed_precision
        self.mixed_precision = None
        if precision == "fp16":
            self.mixed_precision = torch.float16
        elif precision == "bf16":
            self.mixed_precision = torch.bfloat16
        if self.mixed_precision is not None:
            module = module.to(self.mixed_precision)
        module = module.cuda()

        # setting input type cast when using mixed precision
        self.convert_fn = None
        if self.mixed_precision is not None:
            self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)

        # setting ddp configs
        if use_ddp:
            # convert model to sync bn
            module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
            # wrap the model with PyTorch DDP
            module = DDP(module, process_group=dp_group, **ddp_config)

        super().__init__(module)

    def sync_shared_params(self):
        """All-reduce the gradients of parameters shared between pipeline stages.

        Process groups were only created for non-empty entries of ``shared_params``,
        so the entries are filtered the same way before pairing them with their
        groups; zipping the unfiltered list would pair entries with the wrong group
        (and drop trailing ones) as soon as an empty entry is present.
        """
        non_empty_shared_params = [sp for sp in self.shared_params if len(sp) > 0]
        for shared_param, group in zip(non_empty_shared_params, self.shared_param_process_groups):
            if self.stage_manager.stage in shared_param:
                param = shared_param[self.stage_manager.stage]
                dist.all_reduce(param.grad, group=group)
            # every rank reaches the barrier, whether or not it took part in the all-reduce
            dist.barrier()

    def no_sync(self) -> Iterator[None]:
        """Return a no-op context manager.

        Gradient synchronization across data parallel ranks is triggered explicitly
        through ``sync_grads``, so there is nothing to suppress here.
        """
        # no sync grads across data parallel
        return nullcontext()

    def sync_grads(self):
        """Average gradients across the data parallel group (all-reduce, then divide in place)."""
        # sync grad across data parallel
        dp_size = self.dp_group.size()
        if dp_size == 1:
            return
        for p in self.module.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, group=self.dp_group)
                p.grad.div_(dp_size)

    def forward(self, *args, **kwargs):
        """Run the wrapped module, casting floating-point tensor inputs under mixed precision."""
        if self.convert_fn is not None:
            args = tree_map(self.convert_fn, args)
            kwargs = tree_map(self.convert_fn, kwargs)
        return super().forward(*args, **kwargs)

    def unwrap(self):
        """Return the underlying module, stripping the DDP wrapper when present."""
        module = super().unwrap()
        if isinstance(module, DDP):
            module = module.module
        return module

122

123
124
125
126
127
128
129
130
131
132
def get_param_info(optim: Optimizer):
    """Back up parameter bookkeeping of ``optim`` for future use.

    The returned dict contains:
        1. ``"param_groups"``: a copy of every param group in which ``"params"`` holds
           integer param ids instead of the parameters themselves.
        2. ``"param2id"``: mapping from param address (``id(param)``) to integer param id.
        3. ``"id2param"``: mapping from integer param id to param address.
        4. ``"param2shape"``: mapping from param address to the shape the parameter has
           when this function is called (``None`` for entries that are not tensors).

    Param ids are assigned consecutively across groups, starting from 0.
    When Zero is used, the params here are fp16/bf16 model params rather than fp32
    master params in optimizer.

    Args:
        optim: the optimizer to inspect; may be ``None``.

    Returns:
        The dict described above, or an empty dict when ``optim`` is ``None``.
    """
    if optim is None:
        return {}

    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
    start_index = 0
    for group in optim.param_groups:
        # copy every hyper-parameter of the group, but replace the params by their ids
        packed_group = {k: v for k, v in group.items() if k != "params"}
        packed_group["params"] = []

        for param_id, param in enumerate(group["params"], start_index):
            original_shape = param.shape if isinstance(param, torch.Tensor) else None
            packed_group["params"].append(param_id)
            param_info["param2id"][id(param)] = param_id
            param_info["id2param"][param_id] = id(param)
            param_info["param2shape"][id(param)] = original_shape

        param_info["param_groups"].append(packed_group)
        start_index += len(group["params"])

    return param_info


152
def init_pipeline_optimizer(optim: Optimizer, model: Module):
153
    model_params = set(model.parameters())
154
155
    new_param_groups = []
    for group in optim.param_groups:
156
157
158
        params = [p for p in group["params"] if p in model_params]
        new_param_groups.append({**group, "params": params})
    optim.__setstate__({"param_groups": new_param_groups})
159
160


161
class HybridParallelNaiveOptimizer(OptimizerWrapper):
    """Plain optimizer wrapper used by the plugin when neither AMP nor ZeRO is active."""

    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
        """Wrap ``optim``.

        Args:
            optim: the optimizer to wrap.
            model: the model being trained; only consulted when ``use_pipeline`` is set.
            use_pipeline: when True, the optimizer's param groups are first restricted
                to the parameters of ``model`` via ``init_pipeline_optimizer``.
            param_info: parameter bookkeeping (see ``get_param_info``), stored on the
                instance as ``self.param_info``.
        """
        self.param_info = param_info
        if use_pipeline:
            init_pipeline_optimizer(optim, model)
        super().__init__(optim)


class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
    """Mixed precision optimizer used by the plugin for fp16/bf16 training without ZeRO."""

    def __init__(
        self,
        optim: Optimizer,
        model: Module,
        use_pipeline: bool,
        param_info: OrderedDict,
        precision: str = "fp16",
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0,
    ):
        """Wrap ``optim`` in a ``MixedPrecisionOptimizer``.

        Args:
            optim: the optimizer to wrap.
            model: the model being trained; only consulted when ``use_pipeline`` is set.
            use_pipeline: when True, the optimizer's param groups are first restricted
                to the parameters of ``model`` via ``init_pipeline_optimizer``.
            param_info: parameter bookkeeping (see ``get_param_info``), stored on the
                instance as ``self.param_info``.
            precision, initial_scale, min_scale, growth_factor, backoff_factor,
            growth_interval, hysteresis, max_scale, max_norm: forwarded unchanged, in
                this order, to ``MixedPrecisionOptimizer.__init__``.
        """
        self.param_info = param_info
        if use_pipeline:
            init_pipeline_optimizer(optim, model)
        # arguments are forwarded positionally, exactly in the order of this signature
        super().__init__(
            optim,
            precision,
            initial_scale,
            min_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            hysteresis,
            max_scale,
            max_norm,
        )
201
202
203
204


class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
    """ZeRO optimizer used by the plugin when ``zero_stage`` is 1 or 2."""

    def __init__(
        self,
        optimizer: Optimizer,
        model: Module,
        use_pipeline: bool,
        param_info: OrderedDict,
        initial_scale: int = 2**16,  # grad scaler config
        min_scale: int = 1,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        hysteresis: int = 2,
        max_scale: int = 2**24,
        clip_grad_norm: float = 0.0,  # grad clipping
        verbose: bool = False,
        reduce_bucket_size: int = 1024 * 1024,  # communication
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        partition_grad: bool = False,  # stage 2 flag
        cpu_offload: bool = False,  # cpu offload
        dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
        forced_dtype: Optional[torch.dtype] = None,
    ):
        """Wrap ``optimizer`` in a ``LowLevelZeroOptimizer``.

        Args:
            optimizer: the optimizer to wrap.
            model: the model being trained; only consulted when ``use_pipeline`` is set.
            use_pipeline: when True, the optimizer's param groups are first restricted
                to the parameters of ``model`` via ``init_pipeline_optimizer``.
            param_info: parameter bookkeeping (see ``get_param_info``), stored on the
                instance as ``self.param_info``.

        All remaining arguments are forwarded unchanged, in the order of this
        signature, to ``LowLevelZeroOptimizer.__init__``.
        """
        self.param_info = param_info
        if use_pipeline:
            init_pipeline_optimizer(optimizer, model)
        # arguments are forwarded positionally, exactly in the order of this signature
        super().__init__(
            optimizer,
            initial_scale,
            min_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            hysteresis,
            max_scale,
            clip_grad_norm,
            verbose,
            reduce_bucket_size,
            communication_dtype,
            overlap_communication,
            partition_grad,
            cpu_offload,
            dp_process_group,
            tp_process_group,
            forced_dtype,
        )
251
252
253


class HybridParallelPlugin(PipelinePluginBase):
    """
    Plugin for Hybrid Parallel Training.
    Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
    The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import HybridParallelPlugin

        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin =  HybridParallelPlugin(tp_size=2, pp_size=2)

        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

    Args:
        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
        precision (str, optional): Specifies the precision of parameters during training.
                                    Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
                                    Defaults to 'fp16'.
        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
                                        When set to 0, ZeRO will not be used. Defaults to 0.
        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
                                                    Currently all the optimization methods include fused normalization, flash attention and JIT.
                                                    Defaults to False.
        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer.
            Requires ``tp_size > 1``. Defaults to False.
        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
        hysteresis (int, optional):  The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
        cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
    """

    def __init__(
        self,
        tp_size: int,
        pp_size: int,
        precision: str = "fp16",
        zero_stage: int = 0,
        enable_all_optimization: bool = False,
        enable_fused_normalization: bool = False,
        enable_flash_attention: bool = False,
        enable_jit_fused: bool = False,
        enable_sequence_parallelism: bool = False,
        enable_sequence_overlap: bool = False,
        num_microbatches: Optional[int] = None,
        microbatch_size: Optional[int] = None,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0,
        broadcast_buffers: bool = True,
        ddp_bucket_cap_mb: int = 25,
        find_unused_parameters: bool = False,
        check_reduction: bool = False,
        gradient_as_bucket_view: bool = False,
        static_graph: bool = False,
        zero_bucket_size_in_m: int = 12,
        cpu_offload: bool = False,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        custom_policy: Optional[Policy] = None,
    ) -> None:
        super().__init__()
        assert (
            dist.get_world_size() % (tp_size * pp_size) == 0
        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

        if enable_sequence_parallelism:
            # the previous message stated the converse of what is actually checked
            assert tp_size > 1, "Sequence parallelism can only be enabled when tensor parallelism is used (tp_size > 1)"

        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
        self.precision = precision
        self.zero_stage = zero_stage
        self.cpu_offload = cpu_offload
        self.enable_all_optimization = enable_all_optimization
        self.enable_fused_normalization = enable_fused_normalization
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        # mesh axes follow (DP_AXIS, PP_AXIS, TP_AXIS)
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
        self.stage_manager = None
        self.schedule = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
            self.schedule = OneForwardOneBackwardSchedule(
                self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
            )
        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
            pipeline_stage_manager=self.stage_manager,
            enable_tensor_parallelism=self.tp_size > 1,
            enable_all_optimization=self.enable_all_optimization,
            enable_fused_normalization=self.enable_fused_normalization,
            enable_flash_attention=self.enable_flash_attention,
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=enable_sequence_parallelism,
            enable_sequence_overlap=enable_sequence_overlap,
        )
        # loss-scaling settings shared by the AMP and ZeRO optimizers
        self.amp_config = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
        )

        # keyword arguments forwarded to PyTorch DDP
        self.ddp_config = dict(
            broadcast_buffers=broadcast_buffers,
            bucket_cap_mb=ddp_bucket_cap_mb,
            find_unused_parameters=find_unused_parameters,
            check_reduction=check_reduction,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )

        # keyword arguments forwarded to HybridParallelZeroOptimizer
        self.zero_config = dict(
            reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            partition_grad=(self.zero_stage == 2),
        )

        self.max_norm = max_norm

    @property
    def enable_pipeline_parallelism(self) -> bool:
        """Whether pipeline parallelism is active, i.e. ``pp_size > 1``."""
        return self.pp_size > 1

    def supported_devices(self) -> List[str]:
        return ["cuda"]

    def supported_precisions(self) -> List[str]:
        return ["fp16", "bf16", "fp32"]

    def control_device(self) -> bool:
        return True

    def control_precision(self) -> bool:
        return True

    def support_no_sync(self) -> bool:
        return False

    def control_checkpoint_io(self) -> bool:
        return True

    def configure(
        self,
        model: Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        """Wrap the model and optimizer for hybrid parallel training.

        The model is wrapped in ``HybridParallelModule`` unless it already is a
        ``ModelWrapper``. The optimizer, when given and not already an
        ``OptimizerWrapper``, is wrapped according to ``zero_stage`` and ``precision``:
        ZeRO optimizer for stage 1/2, AMP optimizer for fp16/bf16 without ZeRO, and the
        naive wrapper otherwise. ``criterion``, ``dataloader`` and ``lr_scheduler`` are
        returned untouched.

        Returns:
            The tuple ``(model, optimizer, criterion, dataloader, lr_scheduler)``.
        """
        # must be collected before the optimizer's param groups are modified below
        param_info = get_param_info(optimizer)
        if not isinstance(model, ModelWrapper):
            # PyTorch DDP is only used for pure data parallelism without pipeline and ZeRO
            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
            model = HybridParallelModule(
                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
            )
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            # NOTE(review): ``self.checkpoint_io`` is only assigned in ``get_checkpoint_io``;
            # the AMP and ZeRO branches presume it has been called before ``configure``.
            if self.zero_stage == 0:
                if self.precision in ["fp16", "bf16"]:
                    optimizer = HybridParallelAMPOptimizer(
                        optimizer,
                        model,
                        use_pipeline=self.enable_pipeline_parallelism,
                        param_info=param_info,
                        precision=self.precision,
                        max_norm=self.max_norm,
                        **self.amp_config,
                    )
                    self.checkpoint_io.link_master_and_working_param(
                        optimizer.working_to_master_map, optimizer.master_to_working_map
                    )
                else:
                    optimizer = HybridParallelNaiveOptimizer(
                        optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
                    )
            else:
                assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
                assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                optimizer = HybridParallelZeroOptimizer(
                    optimizer,
                    model,
                    use_pipeline=self.enable_pipeline_parallelism,
                    param_info=param_info,
                    dp_process_group=self.dp_group,
                    tp_process_group=self.tp_group,
                    verbose=True,
                    clip_grad_norm=self.max_norm,
                    **self.zero_config,
                    **self.amp_config,
                )
                self.checkpoint_io.link_master_and_working_param(
                    optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
                )

        return model, optimizer, criterion, dataloader, lr_scheduler

    def execute_pipeline(
        self,
        data_iter: Iterator,
        model: HybridParallelModule,
        criterion: Callable[[Any, Any], torch.Tensor],
        optimizer: Optional[
            Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer]
        ] = None,
        return_loss: bool = True,
        return_outputs: bool = False,
    ) -> dict:
        """Run one forward/backward pass through the pipeline schedule and sync gradients.

        Args:
            data_iter: iterator handed to the schedule to fetch the batch.
            model: the wrapped model.
            criterion: loss function handed to the schedule.
            optimizer: the wrapped optimizer, if any.
            return_loss: forwarded to the schedule.
            return_outputs: forwarded to the schedule.

        Returns:
            Whatever ``self.schedule.forward_backward_step`` returns.

        Raises:
            AssertionError: if pipeline parallelism is not enabled.
        """
        assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
        # Run the schedule inside the optimizer's (ZeRO) or the model's no_sync context;
        # gradients are synchronized explicitly once the schedule has finished.
        ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
        with ctx:
            outputs = self.schedule.forward_backward_step(
                model, data_iter, criterion, optimizer, return_loss, return_outputs
            )
        model.sync_shared_params()
        if isinstance(optimizer, HybridParallelZeroOptimizer):
            optimizer.sync_grad()
        else:
            model.sync_grads()
        return outputs

    def prepare_dataloader(
        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
    ):
        r"""
        Prepare a dataloader for distributed training. The dataset is wrapped in a
        `torch.utils.data.DataLoader` driven by a `torch.utils.data.DistributedSampler`
        that splits the data along the data parallel axis of the process group mesh.


        Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
            batch_size (int): Batch size passed to the ``DataLoader``.
            shuffle (bool, optional): Whether the sampler shuffles the dataset. Defaults to False.
            seed (int, optional): Seed applied to ``numpy``, ``torch`` and ``random`` in every
                dataloader worker, defaults to 1024.
            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
                is not divisible by the batch size. If False and the size of dataset is not divisible by
                the batch size, then the last batch will be smaller, defaults to False.
            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
                    `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

        Returns:
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
        """
        _kwargs = kwargs.copy()
        sampler = DistributedSampler(
            dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
        )

        # Deterministic dataloader
        def seed_worker(worker_id):
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **_kwargs,
        )

    def get_checkpoint_io(self) -> CheckpointIO:
        """Create the hybrid parallel checkpoint IO, remember it on the plugin and return it."""
        self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        return self.checkpoint_io

    def no_sync(self, model: Module) -> Iterator[None]:
        """Not supported by this plugin (see ``support_no_sync``)."""
        raise NotImplementedError