ddp.py

import os
import random
from collections import OrderedDict
from typing import Callable, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel

from .base import Strategy
from .sampler import DistributedSampler


# TODO: move this to a utils module (moving it into ray.util introduces a circular import)
def get_grad_required_state_dict(model: nn.Module):
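    """Return an OrderedDict of the parameters that require gradients,
    detached from the autograd graph."""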
    state_dict = OrderedDict()
    for name, parameter in model.named_parameters():
        if parameter.requires_grad:
            state_dict[name] = parameter.detach()
    return state_dict


class DDPStrategy(Strategy):
    """
        Strategy for distributed training using torch.distributed.
    """

    def __init__(self,
                 seed: int = 42,
                 plugin_initializer: Callable = TorchDDPPlugin
                 ) -> None:
        # self.seed must be assigned before super().__init__(), which performs
        # distributed setup and in turn calls set_seed(self.seed).
        self.seed = seed
        super().__init__(plugin_initializer)

    def _try_init_dist(self, force: bool = False) -> None:
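        """Initialize the default NCCL process group from the environment
        variables set by a torch launcher (RANK, LOCAL_RANK, WORLD_SIZE,
        MASTER_ADDR, MASTER_PORT). A missing variable makes this a no-op
        unless force=True, in which case the error is re-raised."""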
        try:
            rank = int(os.environ['RANK'])
            local_rank = int(os.environ['LOCAL_RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            host = os.environ['MASTER_ADDR']
            port = int(os.environ['MASTER_PORT'])
            dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
            torch.cuda.set_device(local_rank)
        except KeyError as e:
            if force:
                raise RuntimeError(
                    f"Could not find {e} in the torch environment. Visit https://www.colossalai.org/ "
                    "for more information on launching with torch."
                ) from e
        except Exception:
            if force:
                raise

    def _post_init(self) -> None:
        assert isinstance(self.plugin, TorchDDPPlugin), \
            f'{type(self).__name__}\'s plugin is not initialized properly.'

    def setup_distributed(self) -> None:
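        """Initialize the process group (raising on failure) and seed all RNGs."""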
        self._try_init_dist(force=True)
        self.set_seed(self.seed)

    def set_seed(self, seed: int) -> None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
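        """Create a DataLoader over the replay buffer via the TorchDDPPlugin,
        which takes care of distributed sampling, batching, and collation."""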
        return self.plugin.prepare_dataloader(replay_buffer,
                                              batch_size=replay_buffer.sample_batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              pin_memory=pin_memory,
                                              collate_fn=replay_buffer.collate_fn)

    def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray and has not been tested since adapting to the Booster API.
        return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
        return model.unwrap()

    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
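        """Save the unwrapped HuggingFace model (and optionally its tokenizer)
        to `path`; with only_rank0=True, only rank 0 writes to disk."""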
        if only_rank0 and dist.get_rank() != 0:
            return
        unwrapped_model = self.unwrap_model(model)
        assert isinstance(unwrapped_model, PreTrainedModel)
        unwrapped_model.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)

    def get_model_state_dict_shard(self, model: nn.Module, **config):
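        """Yield the (unwrapped) model's state dict in shards. If
        `requires_grad_only` is set, only trainable parameters are included;
        if `shard_size` (in bytes) is set, parameters are accumulated into an
        OrderedDict until their total byte size reaches that limit before
        each yield."""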
        # TODO: implement sharding on naive strategy
        model = self.unwrap_model(model)
        if config.get('requires_grad_only', False):
            state_dict = get_grad_required_state_dict(model)
        else:
            state_dict = model.state_dict()

        if 'shard_size' in config:
            shard_size = config['shard_size']
            accumulate_size = 0
            state_dict_shard = OrderedDict()
            for name, param in state_dict.items():
                state_dict_shard[name] = param
                accumulate_size += param.numel() * param.element_size()
                if accumulate_size >= shard_size:
                    accumulate_size = 0
                    yield state_dict_shard
                    state_dict_shard = OrderedDict()
            if accumulate_size > 0:
                yield state_dict_shard
        else:
            yield state_dict
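

# ---------------------------------------------------------------------------
# Minimal smoke test (not part of the original module): verifies that
# get_grad_required_state_dict skips frozen parameters. It runs in a single
# CPU process; exercising DDPStrategy itself would require a torchrun launch
# so that RANK, WORLD_SIZE, MASTER_ADDR, etc. are set in the environment.
if __name__ == '__main__':
    demo = nn.Linear(4, 2)
    demo.bias.requires_grad_(False)    # freeze the bias
    trainable = get_grad_required_state_dict(demo)
    assert 'weight' in trainable and 'bias' not in trainable
    print(f'trainable tensors: {list(trainable.keys())}')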