ddp.py 5.87 KB
Newer Older
1
import os
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
2
import random
3
from collections import OrderedDict
4
from typing import Callable, Optional
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
5
6
7
8
9

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
10
11
from coati.experience_buffer import ExperienceBuffer
from coati.models import Actor, Critic, RewardModel
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
12
from torch.utils.data import DataLoader
13
from transformers.modeling_utils import PreTrainedModel
14
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
15

16
17
18
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel

19
from .base import Strategy
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
20
21
22
from .sampler import DistributedSampler


23
24
25
26
27
28
29
30
31
32
# TODO: Move this to a util.py (moving it to ray.util introduces a circular import).
def get_grad_required_state_dict(model: nn.Module):
    """Collect only the trainable parameters of ``model``.

    Returns an ``OrderedDict`` keyed by the names reported by
    ``model.named_parameters()`` (in that order). Only parameters with
    ``requires_grad`` set are included. Each value is ``param.detach()``: a
    tensor that shares storage with the parameter but is cut from autograd.
    Buffers are not included because only ``named_parameters`` is walked.
    """
    return OrderedDict(
        (param_name, tensor.detach())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    )


class DDPStrategy(Strategy):
    """Strategy for distributed data-parallel training using ``torch.distributed``.

    Model wrapping and dataloader preparation go through a ColossalAI booster
    plugin. ``plugin_initializer`` is handed to the ``Strategy`` base class,
    which is expected to build ``self.plugin`` from it. ``_post_init`` checks
    that the result is a ``TorchDDPPlugin``.
    """

    def __init__(self,
                 seed: int = 42,
                 plugin_initializer: Callable = TorchDDPPlugin
                 ) -> None:
        """
        Args:
            seed: Seed applied to ``random``, NumPy and torch by ``setup_distributed``.
            plugin_initializer: Callable passed to ``Strategy.__init__`` to create the
                booster plugin (defaults to ``TorchDDPPlugin``).
        """
        # Assigned before super().__init__, presumably because the base initializer
        # may already call setup hooks that read it -- kept in the original order.
        self.seed = seed
        super().__init__(plugin_initializer)

    @staticmethod
    def _tcp_init_method(host: str, port: int) -> str:
        """Build a ``tcp://host:port`` URL for ``dist.init_process_group``.

        Only IPv6 literals (which contain ':') are wrapped in brackets.
        Python >= 3.11.4 makes ``urllib.parse.urlsplit`` reject bracketed IPv4
        addresses and hostnames, and torch parses ``init_method`` with it.
        """
        if ':' in host and not host.startswith('['):
            host = f'[{host}]'
        return f'tcp://{host}:{port}'

    def _try_init_dist(self, force: bool = False) -> None:
        """Initialise the default NCCL process group from torchrun-style env vars.

        Reads ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and
        ``MASTER_PORT``. It then calls ``dist.init_process_group`` over TCP and
        selects CUDA device ``LOCAL_RANK`` for this process.

        Args:
            force: When True, failures propagate. When False, every failure is
                swallowed and the call is only a best-effort attempt.

        Raises:
            RuntimeError: ``force`` is True and a required env var is missing.
            Exception: ``force`` is True and initialisation itself fails; the
                original exception is re-raised unchanged.
        """
        try:
            rank = int(os.environ['RANK'])
            local_rank = int(os.environ['LOCAL_RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            host = os.environ['MASTER_ADDR']
            port = int(os.environ['MASTER_PORT'])
            dist.init_process_group('nccl',
                                    init_method=self._tcp_init_method(host, port),
                                    world_size=world_size,
                                    rank=rank)
            torch.cuda.set_device(local_rank)
        except KeyError as e:
            if force:
                raise RuntimeError(
                    f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
                ) from e
        except Exception:
            if force:
                raise

    def _post_init(self) -> None:
        """Check that the base class produced a ``TorchDDPPlugin``."""
        assert isinstance(self.plugin, TorchDDPPlugin), \
            f'{type(self).__name__}\'s plugin is not initialized properly.'

    def setup_distributed(self) -> None:
        """Initialise torch.distributed (raising on failure), then seed all RNGs."""
        self._try_init_dist(force=True)
        self.set_seed(self.seed)

    def set_seed(self, seed: int) -> None:
        """Seed Python's ``random``, NumPy's global RNG and torch (``torch.manual_seed``)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
        """Build a shuffled DataLoader over ``data_buffer`` via the plugin.

        The batch size and collate function come from the buffer itself. The
        last incomplete batch is dropped.
        """
        return self.plugin.prepare_dataloader(data_buffer,
                                              batch_size=data_buffer.sample_batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              pin_memory=pin_memory,
                                              collate_fn=data_buffer.collate_fn)

    def setup_sampler(self, dataset) -> DistributedSampler:
        """Create a ``DistributedSampler`` partitioned by the current world size and rank."""
        # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
        return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Return the module wrapped by a ``TorchDDPModel``.

        Raises:
            AssertionError: ``model`` is not a ``TorchDDPModel``.
        """
        assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
        return model.unwrap()

    @staticmethod
    def _strip_state_dict_prefix(model_path: str, prefix: str) -> None:
        """Rewrite the checkpoint at ``model_path`` in place, removing a leading ``prefix`` from keys.

        Only a *leading* ``prefix`` is removed. Keys that merely contain
        ``prefix`` somewhere else are left intact. The file is loaded onto CPU
        and saved back as a plain ``dict``.
        """
        state_dict = torch.load(model_path, map_location="cpu")
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        torch.save(state_dict, model_path)

    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        """Save a HuggingFace-style directory (config, optional tokenizer, weights) to ``path``.

        Args:
            model: A ``TorchDDPModel`` wrapping an ``Actor``/``Critic``/``RewardModel``
                whose ``.model`` is a ``PreTrainedModel``.
            path: Output directory. Weights go to ``<path>/pytorch_model.bin``.
            only_rank0: If True, only rank 0 writes the config/tokenizer. Also
                forwarded to ``save_model``.
            tokenizer: Optional tokenizer saved alongside the config.
        """
        if not only_rank0 or dist.get_rank() == 0:
            unwrapped_model = self.unwrap_model(model)
            assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
            pretrained_model = unwrapped_model.model
            assert isinstance(pretrained_model, PreTrainedModel)
            # HACK: only use hf save_pretrained to save config -- the no-op save_function
            # suppresses its weight file; weights are written by save_model below.
            pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
            if tokenizer is not None:
                tokenizer.save_pretrained(path)
        model_path = os.path.join(path, "pytorch_model.bin")
        self.save_model(model,
                        model_path,
                        only_rank0=only_rank0)

        # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
        # HACK: strip that leading prefix from the saved file (rank 0 only).
        if dist.get_rank() == 0:
            self._strip_state_dict_prefix(model_path, "model.")

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        """Yield the unwrapped model's state dict, optionally split into size-bounded shards.

        Keyword config:
            requires_grad_only: If truthy, only parameters with ``requires_grad``
                are included (see ``get_grad_required_state_dict``). Otherwise
                ``model.state_dict()`` is used.
            shard_size: Byte threshold. When given (and not None), consecutive
                ``OrderedDict`` shards are yielded. A shard is closed as soon as
                its accumulated ``numel * element_size`` reaches ``shard_size``,
                and the final, possibly smaller, shard is yielded if non-empty.
                When absent or None, the whole state dict is yielded once.

        The body runs lazily, so ``unwrap_model``'s check fires on first iteration.
        """
        # TODO: implement sharding on naive strategy
        model = self.unwrap_model(model)
        if config.get('requires_grad_only', False):
            state_dict = get_grad_required_state_dict(model)
        else:
            state_dict = model.state_dict()

        shard_size = config.get('shard_size')
        if shard_size is None:
            yield state_dict
            return

        accumulate_size = 0
        state_dict_shard = OrderedDict()
        for name, param in state_dict.items():
            state_dict_shard[name] = param
            accumulate_size += param.numel() * param.element_size()
            if accumulate_size >= shard_size:
                yield state_dict_shard
                accumulate_size = 0
                state_dict_shard = OrderedDict()
        # Flush based on content, not byte count: zero-sized tensors add 0 bytes
        # but must still be emitted.
        if state_dict_shard:
            yield state_dict_shard