# manager.py
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple

import torch

from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device

from .chunk import Chunk, ChunkFullError, TensorState


class ChunkManager:
    """
    A manager class to manipulate the tensors in chunks.

    Args:
        chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
            Each value must contain a ``'chunk_size'`` entry; the remaining entries are forwarded
            to the ``Chunk`` constructor as keyword arguments. The dictionary passed in is
            not modified.
        init_device (torch.device): optional, the device on which the chunk is initialized. The default is None,
            in which case ``get_current_device()`` is used.

    Attributes:
        device: the device handed to every newly created chunk as ``init_device``.
        dp_degree_chunk_size_dict: maps a configuration key to its chunk size.
        kwargs_config: maps a configuration key to the extra keyword arguments used to build chunks.
        chunk_groups: maps a group name (``"{group_type}_{config_key}"``) to a deque of chunks.
        tensor_chunk_map: maps a registered tensor to the chunk that owns it.
        accessed_chunks: the set of chunks currently marked as accessed.
        accessed_mem: the sum of ``chunk_mem`` over the accessed chunks.
        total_mem: the tracked memory usage, keyed by device type (``'cpu'`` / ``'cuda'``).
    """

    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
        self.device = init_device or get_current_device()
        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
        self.kwargs_config: Dict[int, Dict] = dict()
        for config_key, config in chunk_configuration.items():
            # work on a shallow copy so that the caller's configuration stays intact
            # and can be reused, e.g. to build another manager
            chunk_kwargs = dict(config)
            self.dp_degree_chunk_size_dict[config_key] = chunk_kwargs.pop('chunk_size')
            chunk_kwargs['init_device'] = self.device
            self.kwargs_config[config_key] = chunk_kwargs

        self.chunk_groups: Dict[str, Deque] = dict()
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
        self.accessed_chunks: Set[Chunk] = set()
        self.accessed_mem: int = 0
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def register_tensor(self,
                        tensor: ColoTensor,
                        group_type: str,
                        config_key: int,
                        cpu_offload: bool = False,
                        pin_memory: bool = False) -> None:
        """
        Register a tensor to the chunk manager.
        Then, the tensor should be accessed by `get_chunks`.

        The tensor is appended to the last chunk of its group. When the group has no chunk yet,
        or the last chunk is full, the last chunk (if any) is closed and a new chunk is created.

        Args:
            tensor: the tensor appended to the chunk
            group_type: the data type of the group.
            config_key: the key of the group's name, the size of the dp world
            cpu_offload: forwarded to the new chunk as ``cpu_shard_init``
            pin_memory: forwarded to the new chunk as ``pin_memory``
        """
        assert tensor not in self.tensor_chunk_map
        assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
        assert config_key in self.dp_degree_chunk_size_dict

        chunk_size = self.dp_degree_chunk_size_dict[config_key]
        chunk_kwargs = self.kwargs_config[config_key]
        group_name = "{}_{}".format(group_type, config_key)
        chunk_group = self.__get_chunk_group(group_name)

        try:
            # append the tensor to the last chunk
            chunk_group[-1].append_tensor(tensor)
        except (IndexError, ChunkFullError):
            # the except statement will be triggered when there is no chunk or
            # the last chunk in the chunk group is full
            # this will create a new chunk and allocate this chunk to its corresponding process
            if chunk_group:
                # the chunk group is not empty
                # close the last chunk
                self.__close_one_chunk(chunk_group[-1])

            if tensor.numel() > chunk_size:
                # the tensor does not fit into a regular chunk: give it a dedicated chunk
                # whose size is rounded up to a multiple of the dp world size
                chunk_size = tensor.numel()
                dp_size = tensor.get_dp_world_size()
                chunk_size = chunk_size + (-chunk_size % dp_size)

            chunk = Chunk(
                chunk_size=chunk_size,
                process_group=tensor.process_group,
                dtype=tensor.dtype,
                cpu_shard_init=cpu_offload,
                pin_memory=pin_memory,
                **chunk_kwargs,
            )

            chunk_group.append(chunk)
            chunk.append_tensor(tensor)
            self.__add_memory_usage(chunk.memory_usage)

        self.tensor_chunk_map[tensor] = chunk_group[-1]

    def close_all_groups(self):
        """Close the last chunk of every group.

        Groups that hold no chunk are skipped.
        """
        for chunk_group in self.chunk_groups.values():
            if chunk_group:
                self.__close_one_chunk(chunk_group[-1])

    def access_chunk(self, chunk: Chunk) -> None:
        """Make the chunk can be used for calculation.

        Nothing happens when the chunk is already accessed. A chunk whose ``device_type``
        is ``'cpu'`` is first moved to the current device.
        """
        if chunk in self.accessed_chunks:
            return
        # the memory usage of the chunk is re-booked around the operation
        self.__sub_memory_usage(chunk.memory_usage)
        if chunk.device_type == 'cpu':
            chunk.shard_move(get_current_device())
        self.__add_accessed_chunk(chunk)
        self.__add_memory_usage(chunk.memory_usage)

    def release_chunk(self, chunk: Chunk) -> None:
        """Scatter the chunk in CUDA.

        Nothing happens when the chunk is not accessed or reports ``can_release`` as false.
        """
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            self.__sub_memory_usage(chunk.memory_usage)
            self.__sub_accessed_chunk(chunk)
            self.__add_memory_usage(chunk.memory_usage)

    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
        """Move the shard of the chunk to the target device.

        Nothing happens when the chunk reports ``can_move`` as false or already
        lives on a device of the requested type.
        """
        if not chunk.can_move or chunk.device_type == device.type:
            return
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.shard_move(device, force_copy)
        self.__add_memory_usage(chunk.memory_usage)

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        """Transit tensor state according to pre-defined state machine.

        Raises:
            KeyError: if the tensor has not been registered.
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, chunk: Chunk) -> bool:
        """Reduce or all reduce the chunk.

        Returns:
            bool: False when the chunk reports ``can_reduce`` as false and nothing was done,
            True after the chunk has been reduced and removed from the accessed chunks.
        """
        if not chunk.can_reduce:
            return False
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.reduce()
        self.__sub_accessed_chunk(chunk)
        self.__add_memory_usage(chunk.memory_usage)
        return True

    def fake_release_chunk(self, chunk: Chunk) -> None:
        """Release gathered chunk in a fake mode.
        This function is used for keep-gathered chunk in the inference mode.

        The chunk must have ``keep_gathered`` set and all of its tensors in the ``HOLD`` state.
        """
        assert chunk.keep_gathered
        assert chunk.tensor_state_cnter[TensorState.HOLD] == chunk.num_tensors
        self.__sub_accessed_chunk(chunk)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        """
        Copy data to the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data (torch.Tensor): the tensor to be copied to the chunk
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        """
        Return the chunk owning the tensor.

        Args:
            tensor (torch.Tensor): a torch tensor object

        Raises:
            KeyError: if the tensor has not been registered.
        """
        return self.tensor_chunk_map[tensor]

    def get_cuda_movable_chunks(self) -> List[Chunk]:
        """
        Get all accessed chunks that report ``can_release``, sorted by their ``count_id``.
        """
        movable_chunks = [chunk for chunk in self.accessed_chunks if chunk.can_release]
        movable_chunks.sort(key=lambda chunk: chunk.count_id)
        return movable_chunks

    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
        """
        Get all chunks owning the input tensors.

        Each chunk appears once, in the order in which it is first encountered.

        Args:
            tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
        """
        # a dict keeps insertion order, which gives an order-preserving de-duplication
        # without scanning a list for every tensor
        return tuple(dict.fromkeys(self.get_chunk(tensor) for tensor in tensors))

    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
        """Add extern static tensor to chunk manager.
        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
        They are "static", which means their shape, dtype, device never change.
        Thus, their memory usage never changes.

        Args:
            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
        """
        assert tensor not in self.tensor_chunk_map
        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()

    def __repr__(self) -> str:
        msg = [
            'Chunk Manager Information:\n',
            'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
        ]
        for group_name, group in self.chunk_groups.items():
            msg.append(f'Group {group_name}:\n')
            for i, chunk in enumerate(group):
                msg.append(f'[{i}] {chunk}\n')
        return ''.join(msg)

    def __get_chunk_group(self, group_name: str) -> Deque:
        """Return the chunk group with the given name, registering an empty one if needed.
        """
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        return self.chunk_groups[group_name]

    def __close_one_chunk(self, chunk: Chunk):
        """Close the chunk and re-book its memory usage around the operation.
        """
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.close_chunk()
        self.__add_memory_usage(chunk.memory_usage)

    def __sub_memory_usage(self, usage: Dict[str, int]):
        """Subtract a per-device-type usage from the tracked total memory.
        """
        for k, v in usage.items():
            self.total_mem[k] -= v

    def __add_memory_usage(self, usage: Dict[str, int]):
        """Add a per-device-type usage to the tracked total memory.
        """
        for k, v in usage.items():
            self.total_mem[k] += v

    def __add_accessed_chunk(self, chunk: Chunk):
        """Access the chunk and record it, together with its ``chunk_mem``, as accessed.
        """
        chunk.access_chunk()
        self.accessed_chunks.add(chunk)
        self.accessed_mem += chunk.chunk_mem

    def __sub_accessed_chunk(self, chunk: Chunk):
        """Release the chunk and drop it, together with its ``chunk_mem``, from the accessed records.
        """
        chunk.release_chunk()
        self.accessed_chunks.remove(chunk)
        self.accessed_mem -= chunk.chunk_mem