colossalai.py 8.96 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
import warnings
2
from typing import Optional
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
3
4
5
6

import torch.nn as nn

import colossalai
7
8
9
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
10
11
12
13

from .ddp import DDPStrategy


14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class LowLevelZeroStrategy(DDPStrategy):
    """
    Strategy for training with ColossalAI's low-level ZeRO plugin (stage 1 or 2).

    The plugin is not built in ``__init__`` directly; a zero-argument
    initializer is handed to the parent strategy, which decides when to call it.

    Args:
        stage(int): The stage to use in ZeRO. Choose in (1, 2)
        precision(str): The precision to use. Choose in ('fp32', 'fp16').
        seed(int): The seed for the random number generator.
        placement_policy(str): Where to keep the training states. Choose in ('cpu', 'cuda').
                          If it is "cpu", the plugin is created with ``cpu_offload=True``,
                          If it is "cuda", nothing is offloaded, which means max CUDA memory will be used. It is the fastest.
        reduce_bucket_size(int): The reduce bucket size in bytes. It is converted to whole MiB
                          (rounded down, at least 1) before being passed to the plugin.
        overlap_communication(bool): Whether to overlap communication and computation.
        initial_scale(float): The initial scale for the optimizer.
        growth_factor(float): The growth factor for the optimizer.
        backoff_factor(float): The backoff factor for the optimizer.
        growth_interval(int): The growth interval for the optimizer.
        hysteresis(int): The hysteresis for the optimizer.
        min_scale(float): The minimum scale for the optimizer.
        max_scale(float): The maximum scale for the optimizer.
        max_norm(float): The maximum norm for the optimizer.
        norm_type(float): The norm type for the optimizer.

    Raises:
        AssertionError: If ``stage``, ``placement_policy`` or ``precision`` is not one of the supported values.
    """

    def __init__(
        self,
        stage: int = 2,
        precision: str = "fp16",
        seed: int = 42,
        placement_policy: str = "cuda",
        reduce_bucket_size: int = 12 * 1024**2,  # only for stage 1&2
        overlap_communication: bool = True,  # only for stage 1&2
        initial_scale: float = 2**16,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        min_scale: float = 1,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
    ) -> None:
        assert stage in (1, 2), f'Unsupported stage "{stage}"'
        assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
        assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'

        # `reduce_bucket_size` is expressed in bytes, while the plugin keyword
        # `reduce_bucket_size_in_m` is (per its name) expressed in MiB. Forwarding
        # the byte count unchanged would inflate the bucket by a factor of 2**20.
        # NOTE(review): confirm the unit against the installed colossalai version.
        reduce_bucket_size_in_m = max(1, reduce_bucket_size // 1024**2)

        def plugin_initializer() -> LowLevelZeroPlugin:
            # Construction is deferred so the parent class controls when the plugin is created.
            return LowLevelZeroPlugin(
                stage=stage,
                precision=precision,
                reduce_bucket_size_in_m=reduce_bucket_size_in_m,
                overlap_communication=overlap_communication,
                cpu_offload=(placement_policy == "cpu"),
                initial_scale=initial_scale,
                growth_factor=growth_factor,
                backoff_factor=backoff_factor,
                growth_interval=growth_interval,
                hysteresis=hysteresis,
                min_scale=min_scale,
                max_scale=max_scale,
                max_norm=max_norm,
                norm_type=norm_type,
            )

        super().__init__(seed, plugin_initializer)

    def _post_init(self) -> None:
        """Check that the plugin created by the initializer is a ``LowLevelZeroPlugin``."""
        assert isinstance(
            self.plugin, LowLevelZeroPlugin
        ), f"{type(self).__name__}'s plugin is not initialized properly."

    def setup_distributed(self) -> None:
        """Launch ColossalAI from the torch launcher environment with an empty config and ``self.seed``."""
        colossalai.launch_from_torch({}, seed=self.seed)

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Return the module wrapped by a ``LowLevelZeroModel``.

        Raises:
            AssertionError: If ``model`` is not a ``LowLevelZeroModel``.
        """
        assert isinstance(model, LowLevelZeroModel)
        return model.module

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        """Yield the shards produced by ``model.state_dict_shard``.

        ``config`` is accepted but not used; the shard size (1024) and
        ``only_rank_0=False`` are fixed here.

        Raises:
            AssertionError: If ``model`` is not a ``LowLevelZeroModel``.
        """
        assert isinstance(model, LowLevelZeroModel)
        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)


class GeminiStrategy(DDPStrategy):
    """
    Strategy for training with ColossalAI's Gemini plugin (ZeRO-3).

    The plugin is always created with ``precision="fp16"``. Plugin creation,
    including the lookup of the current device, is deferred to a zero-argument
    initializer that is handed to the parent strategy.

    Args:
        seed(int): The seed for the random number generator.
        shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
            This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
            The value is also forwarded to the plugin as ``strict_ddp_mode``.
        placement_policy(str): The placement policy for gemini, forwarded unchanged to ``GeminiPlugin``.
            Defaults to "auto".
        shard_param_frac(float): Forwarded to ``GeminiPlugin``. Only for static placement.
        offload_optim_frac(float): Forwarded to ``GeminiPlugin``. Only for static placement.
        offload_param_frac(float): Forwarded to ``GeminiPlugin``. Only for static placement.
        pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
        force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
        search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
        hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
        min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
        gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
        initial_scale(float): The initial scale for the optimizer.
        growth_factor(float): The growth factor for the optimizer.
        backoff_factor(float): The backoff factor for the optimizer.
        growth_interval(int): The growth interval for the optimizer.
        hysteresis(int): The hysteresis for the optimizer.
        min_scale(float): The minimum scale for the optimizer.
        max_scale(float): The maximum scale for the optimizer.
        max_norm(float): The maximum norm for the optimizer.
        norm_type(float): The norm type for the optimizer.

    """

    def __init__(
        self,
        seed: int = 42,
        shard_init: bool = False,  # only for stage 3
        placement_policy: str = "auto",
        shard_param_frac: float = 1.0,  # only for static placement
        offload_optim_frac: float = 0.0,  # only for static placement
        offload_param_frac: float = 0.0,  # only for static placement
        pin_memory: bool = True,  # only for stage 3
        force_outputs_fp32: bool = False,  # only for stage 3
        search_range_m: int = 32,  # only for stage 3
        hidden_dim: Optional[int] = None,  # only for stage 3
        min_chunk_size_m: float = 32,  # only for stage 3
        gpu_margin_mem_ratio: float = 0.0,  # only for stage 3
        initial_scale: float = 2**16,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        min_scale: float = 1,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
    ) -> None:
        # TODO(ver217): support shard_init when using from_pretrained()
        if shard_init:
            warnings.warn(
                "Shard init is not supported model.from_pretrained() yet. "
                "Please load weights after strategy.prepare()"
            )
        self.shard_init = shard_init

        warnings.warn("Stage 3 only supports fp16. Precision is set to fp16.")

        def plugin_initializer() -> GeminiPlugin:
            # NOTE: dist should be initialized before calling get_current_device(),
            # so the device is resolved here, when the initializer is invoked,
            # rather than eagerly in __init__ before super().__init__() has run.
            #
            # colossalai has changed api for get_current_device in 0.3.4 version or newer
            try:
                from colossalai.accelerator import get_accelerator
            except ImportError:
                # Older releases expose the helper from colossalai.utils instead.
                from colossalai.utils import get_current_device

                chunk_init_device = get_current_device()
            else:
                chunk_init_device = get_accelerator().get_current_device()

            return GeminiPlugin(
                chunk_init_device=chunk_init_device,
                placement_policy=placement_policy,
                shard_param_frac=shard_param_frac,
                offload_optim_frac=offload_optim_frac,
                offload_param_frac=offload_param_frac,
                precision="fp16",
                pin_memory=pin_memory,
                force_outputs_fp32=force_outputs_fp32,
                strict_ddp_mode=shard_init,
                search_range_m=search_range_m,
                hidden_dim=hidden_dim,
                min_chunk_size_m=min_chunk_size_m,
                gpu_margin_mem_ratio=gpu_margin_mem_ratio,
                initial_scale=initial_scale,
                growth_factor=growth_factor,
                backoff_factor=backoff_factor,
                growth_interval=growth_interval,
                hysteresis=hysteresis,
                min_scale=min_scale,
                max_scale=max_scale,
                max_norm=max_norm,
                norm_type=norm_type,
            )

        super().__init__(seed, plugin_initializer)

    def _post_init(self) -> None:
        """Check that the plugin created by the initializer is a ``GeminiPlugin``."""
        assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."

    def setup_distributed(self) -> None:
        """Launch ColossalAI from the torch launcher environment with an empty config and ``self.seed``."""
        colossalai.launch_from_torch({}, seed=self.seed)

    def model_init_context(self):
        """Return the parent strategy's model init context unchanged."""
        return super().model_init_context()

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Return the module wrapped by a ``GeminiDDP`` model.

        Raises:
            AssertionError: If ``model`` is not a ``GeminiDDP``.
        """
        assert isinstance(model, GeminiDDP)
        return model.module