from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.utils import free_storage, get_current_device

from .chunk import Chunk, ChunkFullError, TensorState


class ChunkManager:
    """
    A manager class to manipulate the tensors in chunks.

    Args:
        chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
            Every value must contain a "chunk_size" entry; the remaining entries are forwarded
            to the ``Chunk`` constructor as keyword arguments.
        init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
    """

    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
        self.device = init_device or get_current_device()
        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()

        # Work on per-key copies: "chunk_size" is popped and "init_device" is injected below,
        # and doing that on the caller's dictionaries would corrupt them (a second manager
        # built from the same configuration would fail with KeyError).
        self.kwargs_config = {key: dict(config) for key, config in chunk_configuration.items()}
        for key, config in self.kwargs_config.items():
            self.dp_degree_chunk_size_dict[key] = config.pop("chunk_size")
            config["init_device"] = self.device

        # group name -> chunks of that group, in creation order
        self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
        # registered tensor -> chunk that owns it
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
        # chunks currently accessed, and the sum of their `chunk_mem`
        self.accessed_chunks: Set[Chunk] = set()
        self.accessed_mem: int = 0
        # bytes tracked per device type
        self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}

    def register_tensor(
        self,
        tensor: torch.Tensor,
        group_type: str,
        config_key: int,
        zero_group: ProcessGroup,
        extra_dp_group: Optional[ProcessGroup] = None,
        cpu_offload: bool = False,
        pin_memory: bool = False,
    ) -> None:
        """
        Register a tensor to the chunk manager.
        Then, the tensor should be accessed by `get_chunks`.

        Args:
            tensor: the tensor appended to the chunk
            group_type: the data type of the group.
            config_key: the key of the group's name, the size of the dp world
            zero_group: the process group handed to the chunk; its world size is also used
                to align the size of a chunk created for an oversized tensor
            extra_dp_group: optional extra process group, forwarded to the chunk as is
            cpu_offload: if True, the chunk will be closed on CPU
            pin_memory: whether the chunk is pinned in the cpu memory
        """
        assert tensor not in self.tensor_chunk_map
        assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager"
        assert config_key in self.dp_degree_chunk_size_dict

        chunk_size = self.dp_degree_chunk_size_dict[config_key]
        chunk_kwargs = self.kwargs_config[config_key]
        group_name = "{}_{}".format(group_type, config_key)
        chunk_group = self.__get_chunk_group(group_name)

        try:
            # append the tensor to the last chunk
            chunk_group[-1].append_tensor(tensor)
        except (IndexError, ChunkFullError):
            # the except statement will be triggered when there is no chunk or
            # the last chunk in the chunk group is full
            # this will create a new chunk and allocate this chunk to its corresponding process
            if chunk_group:
                # the chunk group is not empty
                # close the last chunk
                self.__close_one_chunk(chunk_group[-1])

            if tensor.numel() > chunk_size:
                # the tensor does not fit in a regular chunk: give it a dedicated chunk whose
                # size is rounded up to the next multiple of the zero group's world size
                chunk_size = tensor.numel()
                dp_size = dist.get_world_size(zero_group)
                chunk_size = chunk_size + (-chunk_size % dp_size)

            chunk = Chunk(
                chunk_size=chunk_size,
                zero_group=zero_group,
                dtype=tensor.dtype,
                cpu_shard_init=cpu_offload,
                pin_memory=pin_memory,
                extra_dp_group=extra_dp_group,
                **chunk_kwargs,
            )

            chunk_group.append(chunk)
            chunk.append_tensor(tensor)
            self.__add_memory_usage(chunk.memory_usage)

        self.tensor_chunk_map[tensor] = chunk_group[-1]

    def close_all_groups(self):
        """Close the last chunk of every group."""
        for chunk_group in self.chunk_groups.values():
            # groups created by `register_tensor` always hold a chunk; the guard only
            # protects against an empty group inserted from outside
            if chunk_group:
                self.__close_one_chunk(chunk_group[-1])

    def access_chunk(self, chunk: Chunk) -> None:
        """Make the chunk usable for computation. Does nothing if it is already accessed."""
        if chunk in self.accessed_chunks:
            return
        # the chunk's memory usage may change below, so take it out of the total
        # first and add the refreshed value back afterwards
        self.__sub_memory_usage(chunk.memory_usage)
        if chunk.device_type == "cpu":
            chunk.shard_move(get_current_device())
        self.__add_accessed_chunk(chunk)
        self.__add_memory_usage(chunk.memory_usage)

    def release_chunk(self, chunk: Chunk) -> None:
        """Scatter the chunk in CUDA.

        Does nothing if the chunk is not currently accessed or reports `can_release` as False.
        """
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            self.__sub_memory_usage(chunk.memory_usage)
            self.__sub_accessed_chunk(chunk)
            self.__add_memory_usage(chunk.memory_usage)

    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
        """Move the shard of the chunk to the target device.

        Does nothing if the chunk reports `can_move` as False or already lives on a device
        of the requested type.
        """
        if not chunk.can_move or chunk.device_type == device.type:
            return
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.shard_move(device, force_copy)
        self.__add_memory_usage(chunk.memory_usage)

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        """Transit tensor state according to pre-defined state machine."""
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, chunk: Chunk) -> bool:
        """Reduce or all reduce the chunk.

        Returns:
            bool: False if the chunk reports `can_reduce` as False (nothing is done),
            True once the chunk has been reduced and removed from the accessed chunks.
        """
        if not chunk.can_reduce:
            return False
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.reduce()
        self.__sub_accessed_chunk(chunk)
        self.__add_memory_usage(chunk.memory_usage)
        return True

    def fake_release_chunk(self, chunk: Chunk) -> None:
        """Release gathered chunk in a fake mode.
        This function is used for keep-gathered chunk in the inference mode.
        """
        assert chunk.keep_gathered
        assert chunk.tensor_state_cnter[TensorState.HOLD] == chunk.num_tensors
        self.__sub_accessed_chunk(chunk)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        """
        Copy data to the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data (torch.Tensor): the tensor to be copied to the chunk
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        """
        Return the chunk owning the tensor.

        Args:
            tensor (torch.Tensor): a torch tensor object

        Raises:
            KeyError: if the tensor has not been registered
        """
        return self.tensor_chunk_map[tensor]

    def get_cuda_movable_chunks(self) -> List[Chunk]:
        """
        Get all accessed chunks that can be released, sorted by their `count_id`.
        """
        movable = [chunk for chunk in self.accessed_chunks if chunk.can_release]
        movable.sort(key=lambda chunk: chunk.count_id)
        return movable

    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
        """
        Get all chunks owning the input tensors.

        Each chunk appears once, in the order it is first encountered.

        Args:
            tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
        """
        # dict keys keep insertion order, which gives an order-preserving de-duplication
        # without rescanning a list for every tensor; chunks are hashable since they are
        # already stored in the `accessed_chunks` set
        return tuple(dict.fromkeys(self.get_chunk(tensor) for tensor in tensors))

    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
        """Add extern static tensor to chunk manager.
        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
        They are "static", which means their shape, dtype, device never change.
        Thus, their memory usage never changes.

        Args:
            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
        """
        assert tensor not in self.tensor_chunk_map
        device_type = tensor.device.type
        if device_type == "npu":
            # NPU memory is accounted under the "cuda" entry
            device_type = "cuda"
        # `total_mem` only starts with "cpu" and "cuda"; begin from 0 for any other device
        # type instead of raising KeyError
        self.total_mem[device_type] = self.total_mem.get(device_type, 0) + tensor.numel() * tensor.element_size()

    def __repr__(self) -> str:
        msg = [
            "Chunk Manager Information:\n",
            "Total memory: " + ", ".join([f"{k}={v}B" for k, v in self.total_mem.items()]) + "\n",
        ]
        for group_name, group in self.chunk_groups.items():
            msg.append(f"Group {group_name}:\n")
            for i, chunk in enumerate(group):
                msg.append(f"[{i}] {chunk}\n")
        return "".join(msg)

    def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
        """Return the chunk group called `group_name`, creating an empty one if needed."""
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        return self.chunk_groups[group_name]

    def __close_one_chunk(self, chunk: Chunk):
        """Close `chunk` and refresh its contribution to `total_mem`."""
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.close_chunk()
        self.__add_memory_usage(chunk.memory_usage)

    def __sub_memory_usage(self, usage: Dict[str, int]):
        """Subtract a per-device usage dictionary from `total_mem`."""
        for k, v in usage.items():
            self.total_mem[k] -= v

    def __add_memory_usage(self, usage: Dict[str, int]):
        """Add a per-device usage dictionary to `total_mem`."""
        for k, v in usage.items():
            self.total_mem[k] += v

    def __add_accessed_chunk(self, chunk: Chunk):
        """Access `chunk` and record it (and its `chunk_mem`) as accessed."""
        chunk.access_chunk()
        self.accessed_chunks.add(chunk)
        self.accessed_mem += chunk.chunk_mem

    def __sub_accessed_chunk(self, chunk: Chunk):
        """Release `chunk` and drop it (and its `chunk_mem`) from the accessed records."""
        chunk.release_chunk()
        self.accessed_chunks.remove(chunk)
        self.accessed_mem -= chunk.chunk_mem

    def init_grad_chunk(self, chunk: Chunk) -> Chunk:
        """(Re-)initialize the gradient chunk of `chunk` and register it as accessed.

        Args:
            chunk (Chunk): the chunk whose gradient chunk is initialized

        Returns:
            Chunk: the gradient chunk returned by `chunk.init_grad_chunk()`
        """
        if chunk.grad_chunk is not None:
            # an existing gradient chunk is already counted in `total_mem`;
            # remove it before the refreshed usage is added below
            self.__sub_memory_usage(chunk.grad_chunk.memory_usage)
        grad_chunk = chunk.init_grad_chunk()
        self.__add_memory_usage(grad_chunk.memory_usage)
        if grad_chunk not in self.accessed_chunks:
            self.accessed_chunks.add(grad_chunk)
            self.accessed_mem += grad_chunk.chunk_mem
        return grad_chunk

    def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
        """Rearrange gradients accumulated in chunk.grad_chunk, and get prepared for gradient reduction.

        Args:
            chunk (Chunk): a chunk whose `grad_chunk` is not None

        Returns:
            Chunk: the re-initialized gradient chunk holding the backed-up gradients
        """

        assert chunk.grad_chunk is not None

        # Make a backup for gradient accumulated before.
        # Here backup gradients should be multiplied, since it will be divided after gradient reduction.
        if chunk.grad_chunk.is_gathered:
            accumulated_grad = chunk.grad_chunk.cuda_global_chunk.clone().detach().mul_(chunk.pg_size)
            accumulated_grad_gathered = True
        else:
            if chunk.grad_chunk.cuda_shard is not None:
                accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
            else:
                accumulated_grad = (
                    chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
                )
            accumulated_grad_gathered = False

        # Reset grad_chunk, and chunk.grad_chunk will be accessed.
        grad_chunk = self.init_grad_chunk(chunk)
        grad_chunk.cuda_global_chunk.zero_()

        # Add backup gradients to grad_chunk.
        if accumulated_grad_gathered:
            grad_chunk.cuda_global_chunk.add_(accumulated_grad)
        else:
            # only a shard was backed up: write it back into this rank's slice of the global chunk
            grad_chunk.cuda_global_chunk[grad_chunk.shard_begin : grad_chunk.shard_end].add_(accumulated_grad)

        # Release accumulated_grad
        free_storage(accumulated_grad)

        return grad_chunk