# hybrid_parallel_plugin.py
1
2
import random
from contextlib import nullcontext
3
from functools import partial
4
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
5
6
7
8
9

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
10
11
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
12
13
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
14
from torch.utils._pytree import tree_map
15
16
17
18
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
19
from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
20
21
22
23
24
25
26
27
28
29
30
31
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.zero.low_level import LowLevelZeroOptimizer

from .pp_plugin_base import PipelinePluginBase

# Axis indices into the ProcessGroupMesh, which the plugin builds as (dp_size, pp_size, tp_size).
DP_AXIS = 0
PP_AXIS = 1
TP_AXIS = 2


32
33
34
35
36
37
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


38
39
class HybridParallelModule(ModelWrapper):
    """Model wrapper used by ``HybridParallelPlugin``.

    On construction it shards ``module`` with Shardformer, optionally casts it to fp16/bf16, moves it
    to CUDA and, when ``use_ddp`` is set, converts BatchNorm layers to ``SyncBatchNorm`` and wraps the
    result in PyTorch DDP over ``dp_group``.

    Args:
        module (Module): The model to wrap.
        precision (str): 'fp16' or 'bf16' casts the module and floating-point forward inputs to that
            dtype; any other value leaves dtypes unchanged.
        shard_config (ShardConfig): Shardformer configuration; its ``pipeline_stage_manager`` is kept
            for synchronising parameters shared across pipeline stages.
        dp_group (ProcessGroup): Data-parallel process group.
        use_ddp (bool): Whether to wrap the sharded module with DDP.
        ddp_config (dict): Keyword arguments forwarded to ``DistributedDataParallel``.
    """

    def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
                 ddp_config: dict) -> None:

        self.stage_manager = shard_config.pipeline_stage_manager
        self.dp_group = dp_group

        shardformer = ShardFormer(shard_config)
        module, self.shared_params = shardformer.optimize(module)

        # setting process groups for shared parameters.
        # Only non-empty entries of ``shared_params`` get a process group, so this list is shorter than
        # ``shared_params`` whenever an entry is empty; ``sync_shared_params`` filters the same way.
        self.shared_param_process_groups = []
        for shared_param in self.shared_params:
            if len(shared_param) > 0:
                self.shared_param_process_groups.append(
                    self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))

        # setting mixed_precision
        self.mixed_precision = None
        if precision == 'fp16':
            self.mixed_precision = torch.float16
        elif precision == 'bf16':
            self.mixed_precision = torch.bfloat16
        if self.mixed_precision is not None:
            module = module.to(self.mixed_precision)
        module = module.cuda()

        # setting input type cast when using mixed precision
        self.convert_fn = None
        if self.mixed_precision is not None:
            self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)

        # setting ddp configs
        if use_ddp:
            # convert model to sync bn
            module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
            # wrap the model with PyTorch DDP
            module = DDP(module, process_group=dp_group, **ddp_config)

        super().__init__(module)

    def sync_shared_params(self):
        """All-reduce the gradients of parameters shared across pipeline stages.

        ``shared_param_process_groups`` holds one group per NON-EMPTY entry of ``shared_params``, so the
        entries are filtered identically before pairing; zipping the unfiltered list would pair
        parameters with the wrong process group. Every rank calls ``dist.barrier()`` once per pair.
        """
        non_empty_shared_params = [shared_param for shared_param in self.shared_params if len(shared_param) > 0]
        for shared_param, group in zip(non_empty_shared_params, self.shared_param_process_groups):
            if self.stage_manager.stage in shared_param:
                param = shared_param[self.stage_manager.stage]
                dist.all_reduce(param.grad, group=group)
            dist.barrier()

    def no_sync(self) -> Iterator[None]:
        """Return a no-op context; gradients are synchronised explicitly via ``sync_grads``."""
        # no sync grads across data parallel
        return nullcontext()

    def sync_grads(self):
        """Average parameter gradients over the data-parallel group (no-op for a single-rank group)."""
        dp_size = self.dp_group.size()
        if dp_size == 1:
            return
        for p in self.module.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, group=self.dp_group)
                p.grad.div_(dp_size)

    def forward(self, *args, **kwargs):
        """Cast floating-point tensors in the inputs to the mixed-precision dtype (if any), then forward."""
        if self.convert_fn is not None:
            args = tree_map(self.convert_fn, args)
            kwargs = tree_map(self.convert_fn, kwargs)
        return super().forward(*args, **kwargs)

    def unwrap(self):
        """Return the underlying module, stripping the DDP wrapper when present."""
        module = super().unwrap()
        if isinstance(module, DDP):
            module = module.module
        return module

112

113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def get_param_info(optim: Optimizer):
    """Snapshot parameter bookkeeping of ``optim`` for later use (e.g. checkpointing).

    The returned dict holds:
        * ``param_groups``: copies of each param group with ``params`` replaced by integer param ids;
        * ``param2id``: ``id(param)`` -> integer param id;
        * ``id2param``: integer param id -> ``id(param)``;
        * ``param2shape``: ``id(param)`` -> the parameter's shape before sharding (None for non-tensors).

    Ids are assigned consecutively across all groups in order. When ZeRO is used, the params seen here
    are the fp16/bf16 model params rather than the fp32 master params held by the optimizer.

    Returns an empty dict when ``optim`` is None.
    """
    if optim is None:
        return {}

    packed_groups = []
    addr_to_id = {}
    id_to_addr = {}
    addr_to_shape = {}
    next_id = 0
    for group in optim.param_groups:
        packed = {key: value for key, value in group.items() if key != 'params'}
        group_ids = []
        for param in group['params']:
            addr = id(param)
            group_ids.append(next_id)
            addr_to_id[addr] = next_id
            id_to_addr[next_id] = addr
            addr_to_shape[addr] = param.shape if isinstance(param, torch.Tensor) else None
            next_id += 1
        # 'params' is added last, after all other hyper-parameters of the group
        packed['params'] = group_ids
        packed_groups.append(packed)

    return {
        'param_groups': packed_groups,
        'param2id': addr_to_id,
        'id2param': id_to_addr,
        'param2shape': addr_to_shape,
    }


143
def init_pipeline_optimizer(optim: Optimizer, model: Module):
    """Restrict ``optim`` in place to the parameters owned by ``model``.

    Under pipeline parallelism each stage only holds part of the model, so parameters of ``optim``
    that are not in ``model.parameters()`` are dropped from every param group. Group hyper-parameters
    are kept as they are; the filtered groups are installed through ``optim.__setstate__``.
    """
    owned = set(model.parameters())
    filtered_groups = []
    for group in optim.param_groups:
        kept = list(filter(lambda p: p in owned, group['params']))
        filtered_groups.append(dict(group, params=kept))
    optim.__setstate__({'param_groups': filtered_groups})


152
153
class HybridParallelNaiveOptimizer(OptimizerWrapper):
    """Plain optimizer wrapper used by the hybrid parallel plugin when neither AMP nor ZeRO applies.

    Args:
        optim (Optimizer): The optimizer to wrap.
        model (Module): The (possibly wrapped) model of this rank.
        use_pipeline (bool): When True, parameters of ``optim`` that ``model`` does not own are dropped.
        param_info (OrderedDict): Parameter bookkeeping (as built by ``get_param_info``), stored as
            ``self.param_info``.
    """

    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
        if use_pipeline:
            # keep only the parameters that live on this pipeline stage
            init_pipeline_optimizer(optim, model)
        self.param_info = param_info
        super().__init__(optim)


class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
    """Mixed-precision optimizer used by the hybrid parallel plugin when ZeRO is disabled.

    ``param_info`` is stored on the instance. When ``use_pipeline`` is True, parameters of ``optim``
    not owned by ``model`` are dropped before the base ``MixedPrecisionOptimizer`` is initialised.
    ``precision``, the loss-scaler settings and ``max_norm`` are forwarded positionally to the base class.
    """

    def __init__(self,
                 optim: Optimizer,
                 model: Module,
                 use_pipeline: bool,
                 param_info: OrderedDict,
                 precision: str = 'fp16',
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0):
        if use_pipeline:
            # keep only the parameters that live on this pipeline stage
            init_pipeline_optimizer(optim, model)
        self.param_info = param_info
        # loss-scaler settings, in the positional order the base class expects
        scaler_args = (initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale)
        super().__init__(optim, precision, *scaler_args, max_norm)


class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
    """ZeRO (stage 1/2) optimizer used by the hybrid parallel plugin.

    ``param_info`` is stored on the instance. When ``use_pipeline`` is True, parameters of ``optimizer``
    not owned by ``model`` are dropped before ``LowLevelZeroOptimizer`` is initialised. All remaining
    arguments are forwarded positionally, in the base class's order.
    """

    def __init__(
            self,
            optimizer: Optimizer,
            model: Module,
            use_pipeline: bool,
            param_info: OrderedDict,
            initial_scale: int = 2**16,    # grad scaler config
            min_scale: int = 1,
            growth_factor: float = 2.,
            backoff_factor: float = .5,
            growth_interval: int = 2000,
            hysteresis: int = 2,
            max_scale: int = 2**24,
            clip_grad_norm: float = 0.0,    # grad clipping
            verbose: bool = False,
            reduce_bucket_size: int = 1024 * 1024,    # communication
            communication_dtype: Optional[torch.dtype] = None,
            overlap_communication: bool = True,
            partition_grad: bool = False,    # stage 2 flag
            cpu_offload: bool = False,    # cpu offload
            dp_process_group: Optional[ProcessGroup] = None,    # the dp pg for comm
            tp_process_group: Optional[ProcessGroup] = None,    # if using tp
            forced_dtype: Optional[torch.dtype] = None):
        if use_pipeline:
            # keep only the parameters that live on this pipeline stage
            init_pipeline_optimizer(optimizer, model)
        self.param_info = param_info
        scaler_args = (initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale)
        comm_args = (reduce_bucket_size, communication_dtype, overlap_communication)
        super().__init__(optimizer, *scaler_args, clip_grad_norm, verbose, *comm_args, partition_grad, cpu_offload,
                         dp_process_group, tp_process_group, forced_dtype)


class HybridParallelPlugin(PipelinePluginBase):
    """
    Plugin for Hybrid Parallel Training.
    Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
    The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import HybridParallelPlugin

        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)

        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

    Args:
        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
        precision (str, optional): Specifies the precision of parameters during training.
                                    Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
                                    Defaults to 'fp16'.
        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
                                        When set to 0, ZeRO will not be used. Must be 0 or 1 when pp_size > 1. Defaults to 0.
        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
                                                    Currently all the optimization methods include fused normalization, flash attention and JIT.
                                                    Defaults to False.
        enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
        enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT. Defaults to False.
        enable_sequence_parallelism (bool, optional): Whether to switch on sequence parallelism in Shardformer.
                                                    Requires tp_size > 1. Defaults to False.
        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism.
                                            Must be specified when pp_size > 1. Defaults to None.
        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
        cpu_offload (bool, optional): Whether to enable CPU offloading when using ZeRO. Defaults to False.
        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
    """

    def __init__(self,
                 tp_size: int,
                 pp_size: int,
                 precision: str = 'fp16',
                 zero_stage: int = 0,
                 enable_all_optimization: bool = False,
                 enable_fused_normalization: bool = False,
                 enable_flash_attention: bool = False,
                 enable_jit_fused: bool = False,
                 enable_sequence_parallelism: bool = False,
                 num_microbatches: Optional[int] = None,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0,
                 broadcast_buffers: bool = True,
                 ddp_bucket_cap_mb: int = 25,
                 find_unused_parameters: bool = False,
                 check_reduction: bool = False,
                 gradient_as_bucket_view: bool = False,
                 static_graph: bool = False,
                 zero_bucket_size_in_m: int = 12,
                 cpu_offload: bool = False,
                 communication_dtype: Optional[torch.dtype] = None,
                 overlap_communication: bool = True) -> None:

        super().__init__()
        assert dist.get_world_size() % (
            tp_size * pp_size
        ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'

        if enable_sequence_parallelism:
            # sequence parallelism is built on top of tensor parallelism
            assert tp_size > 1, 'Tensor parallelism must be enabled when using sequence parallelism'

        # validate before any process groups are created
        assert zero_stage in (0, 1, 2), f'zero_stage must be 0, 1 or 2, got {zero_stage}'

        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
        self.precision = precision
        self.zero_stage = zero_stage
        self.cpu_offload = cpu_offload
        self.enable_all_optimization = enable_all_optimization
        self.enable_fused_normalization = enable_fused_normalization
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        # mesh axes are ordered (dp, pp, tp), matching DP_AXIS / PP_AXIS / TP_AXIS
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
        self.stage_manager = None
        self.schedule = None
        if self.pp_size > 1:
            assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
            assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
            self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
        self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                        pipeline_stage_manager=self.stage_manager,
                                        enable_tensor_parallelism=self.tp_size > 1,
                                        enable_all_optimization=self.enable_all_optimization,
                                        enable_fused_normalization=self.enable_fused_normalization,
                                        enable_flash_attention=self.enable_flash_attention,
                                        enable_jit_fused=self.enable_jit_fused,
                                        enable_sequence_parallelism=enable_sequence_parallelism)
        self.amp_config = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
        )

        self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
                               bucket_cap_mb=ddp_bucket_cap_mb,
                               find_unused_parameters=find_unused_parameters,
                               check_reduction=check_reduction,
                               gradient_as_bucket_view=gradient_as_bucket_view,
                               static_graph=static_graph)

        self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
                                communication_dtype=communication_dtype,
                                overlap_communication=overlap_communication,
                                cpu_offload=cpu_offload,
                                partition_grad=(self.zero_stage == 2))

        self.max_norm = max_norm

    @property
    def enable_pipeline_parallelism(self) -> bool:
        return self.pp_size > 1

    def supported_devices(self) -> List[str]:
        return ['cuda']

    def supported_precisions(self) -> List[str]:
        return ['fp16', 'bf16', 'fp32']

    def control_device(self) -> bool:
        return True

    def control_precision(self) -> bool:
        return True

    def support_no_sync(self) -> bool:
        return False

    def control_checkpoint_io(self) -> bool:
        return True

    def configure(
        self,
        model: Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        """Wrap ``model`` and ``optimizer`` for hybrid parallel training.

        The model is wrapped in ``HybridParallelModule`` (with DDP only when dp_size > 1, pp_size == 1
        and ZeRO is off). The optimizer becomes an AMP, naive or ZeRO optimizer depending on
        ``zero_stage`` and ``precision``. Already-wrapped objects are returned untouched.

        Note: the AMP and ZeRO branches use ``self.checkpoint_io``, which is only set by
        ``get_checkpoint_io``; that method must have been called first.
        """
        # captured before wrapping so it describes the user's original (unwrapped) optimizer params
        param_info = get_param_info(optimizer)
        if not isinstance(model, ModelWrapper):
            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
            model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
                                         self.ddp_config)
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.zero_stage == 0:
                if self.precision in ['fp16', 'bf16']:
                    optimizer = HybridParallelAMPOptimizer(optimizer,
                                                           model,
                                                           use_pipeline=self.enable_pipeline_parallelism,
                                                           param_info=param_info,
                                                           precision=self.precision,
                                                           max_norm=self.max_norm,
                                                           **self.amp_config)
                    self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
                                                                     optimizer.master_to_working_map)
                else:
                    optimizer = HybridParallelNaiveOptimizer(optimizer,
                                                             model,
                                                             use_pipeline=self.enable_pipeline_parallelism,
                                                             param_info=param_info)
            else:
                assert self.dp_size > 1, "Please use ZeRO only when data parallel size is greater than 1."
                assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                optimizer = HybridParallelZeroOptimizer(optimizer,
                                                        model,
                                                        use_pipeline=self.enable_pipeline_parallelism,
                                                        param_info=param_info,
                                                        dp_process_group=self.dp_group,
                                                        tp_process_group=self.tp_group,
                                                        verbose=True,
                                                        clip_grad_norm=self.max_norm,
                                                        **self.zero_config,
                                                        **self.amp_config)
                self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
                                                                 optimizer._param_store.master_to_working_param)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def execute_pipeline(self,
                         data_iter: Iterator,
                         model: HybridParallelModule,
                         criterion: Callable[[Any, Any], torch.Tensor],
                         optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
                                          HybridParallelZeroOptimizer],
                         return_loss: bool = True,
                         return_outputs: bool = False) -> dict:
        """Run one pipeline forward/backward pass, then synchronise gradients.

        Gradient sync is suppressed during the schedule and done once afterwards: shared parameters
        across stages first, then the data-parallel reduction (ZeRO's ``sync_grad`` or the model's
        ``sync_grads``). Returns the dict produced by the schedule's ``forward_backward_step``.
        """
        assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
        # return loss or outputs if needed
        ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
        with ctx:
            outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
                                                          return_outputs)
        model.sync_shared_params()
        if isinstance(optimizer, HybridParallelZeroOptimizer):
            optimizer.sync_grad()
        else:
            model.sync_grads()
        return outputs

    def prepare_dataloader(self,
                           dataset,
                           batch_size,
                           shuffle=False,
                           seed=1024,
                           drop_last=False,
                           pin_memory=False,
                           num_workers=0,
                           **kwargs):
        r"""
        Prepare a dataloader for distributed training. The dataloader will be wrapped by
        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.

        The sampler shards ``dataset`` over the data-parallel axis of the process group mesh, so ranks
        that differ only in their tensor/pipeline coordinate receive the same samples.

        Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
            batch_size (int): Number of samples per batch on each data-parallel rank.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
            seed (int, optional): Seed applied to ``numpy``, ``torch`` and ``random`` in every dataloader
                worker. Defaults to 1024.
            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
                is not divisible by the batch size. If False and the size of dataset is not divisible by
                the batch size, then the last batch will be smaller, defaults to False.
            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
                    `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

        Returns:
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
        """
        sampler = DistributedSampler(dataset,
                                     num_replicas=self.pg_mesh.size(DP_AXIS),
                                     rank=self.pg_mesh.coordinate(DP_AXIS),
                                     shuffle=shuffle)

        # Deterministic dataloader
        # NOTE(review): every worker receives the same seed (worker_id is ignored), so per-worker
        # random augmentation streams are identical; confirm this is intended.
        def seed_worker(worker_id):
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(dataset,
                          batch_size=batch_size,
                          sampler=sampler,
                          worker_init_fn=seed_worker,
                          drop_last=drop_last,
                          pin_memory=pin_memory,
                          num_workers=num_workers,
                          **kwargs)

    def get_checkpoint_io(self) -> CheckpointIO:
        """Create the hybrid-parallel checkpoint IO, keep it on ``self.checkpoint_io`` and return it."""
        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        return self.checkpoint_io

    def no_sync(self, model: Module) -> Iterator[None]:
        raise NotImplementedError