import colossalai
import torch
import copy
from typing import List, Callable, Any, Tuple, Dict, Iterable

from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
CODEGEN_AVAILABLE = True
__all__ = ['ChunkCodeGen']

class NodeIndexTracer(object):
    """
    Trace how tensor-dimension indices flow through an fx graph.

    For every node in the graph, ``idx_trace_list`` keeps a dict with:
      * ``idx``: one integer id per dimension of the node's output tensor
      * ``compute``: the subset of those ids along which a computation
        (e.g. a matmul contraction, softmax, layernorm) has happened
    ``idx_trace_equal`` collects pairs of ids known to denote the same
    dimension, and ``idx_view_list`` logs view/reshape dim changes.
    """

    def __init__(self, gm) -> None:
        # gm: a torch.fx.GraphModule whose nodes carry 'tensor_meta' metadata
        self.gm = gm
        self.nodes_list = list(gm.graph.nodes)
        # one {'idx': [...], 'compute': [...]} entry per node, parallel to nodes_list
        self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
        self.idx_trace_equal = []
        self.idx_view_list = []
        # incremented before use, so the first id handed out is 0
        self.idx_count = -1

    def add_index(self):
        """
        Update the count and return it. To record the idx number.

        Returns:
            idx_count: int
        """
        self.idx_count += 1
        return self.idx_count

    def inherit_computation(self, node_from, node_to):
        """
        Inherit computed dim from node_from to node_to.
        If a dim in node_from is marked as computed and exists in node_to,
        still mark it as computed in node_to.

        Args:
            node_from (node): node to be inherited
            node_to (node): new node to inherit
        """
        _, compute_from = self.find_trace_from_node(node_from)
        idx_to, compute_to = self.find_trace_from_node(node_to)
        for i in compute_from:
            # only propagate ids that still exist in node_to's index trace
            if i in idx_to and i not in compute_to:
                compute_to.append(i)

    def mark_idx_equal(self, idx1, idx2):
        """
        Mark 2 index to be equal.

        Args:
            idx1 (int): index count.
            idx2 (int): index count.
        """
        self.idx_trace_equal.append((idx1, idx2))

    def mark_computation(self, node, idx, dim):
        """
        Mark some dims of node as computed.

        Args:
            node (node)
            idx (int): node index
            dim (list or int): dims to be marked as computed
        """
        input_node_idx_trace = self.find_idx_trace_from_node(node)
        if isinstance(dim, int):
            dim = [dim]
        for d in dim:
            # translate the dim position into its dimension id
            cur_idx = input_node_idx_trace[d]
            if cur_idx not in self.idx_trace_list[idx]['compute']:
                self.idx_trace_list[idx]['compute'].append(cur_idx)

    def find_trace_from_node(self, node):
        """
        Find node idx and compute trace by the node.

        Args:
            node (node)
        Returns:
            idx (list): idx of the node
            compute (list): computed idx of the node.
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        node_dict = self.idx_trace_list[node_idx]
        return node_dict['idx'], node_dict['compute']

    def find_idx_trace_from_node(self, node):
        """
        Find node idx trace by the node.

        Args:
            node (node)
        Returns:
            idx (list): idx of the node
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        return self.idx_trace_list[node_idx]['idx']

    def find_compute_trace_from_node(self, node):
        """
        Find node compute trace by the node.

        Args:
            node (node)
        Returns:
            compute (list): computed idx of the node.
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        return self.idx_trace_list[node_idx]['compute']

    def assign_index_as_input(self, node, node_idx):
        """
        Assign node's trace as its input node.

        Args:
            node (node)
            node_idx (int)
        """
        input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
        input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']

        # deepcopy so later mutation of this node's trace can't leak into the input's
        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

    def assign_all_index(self, node, node_idx):
        """
        Add new index for all node's dims.

        Args:
            node (node)
            node_idx (int)
        """
        shape = node.meta['tensor_meta'].shape
        new_trace = []
        for _ in shape:
            new_trace.append(self.add_index())
        self.idx_trace_list[node_idx]['idx'] = new_trace

    def assign_transpose_index(self, node, node_idx):
        """
        Assign index for transpose op.
        1. swap input's dim according to transpose args
        2. inherit input's computation

        Args:
            node (node)
            node_idx (int)
        """
        # node.args = (input, dim0, dim1)
        tranpose_dim = node.args[1:]
        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]]
        new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]

        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
        self.inherit_computation(node.args[0], node)

    def assign_permute_index(self, node, node_idx):
        """
        Assign index for permute op.
        1. swap input's dim according to permute args
        2. inherit input's computation

        Args:
            node (node)
            node_idx (int)
        """
        # node.args = (input, d0, d1, ...) — output dim i comes from input dim permute_dim[i]
        permute_dim = node.args[1:]
        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        for idx, d in enumerate(permute_dim):
            new_idx_trace[idx] = input_node_idx_trace[d]

        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
        self.inherit_computation(node.args[0], node)

    def assign_linear_index(self, node, node_idx):
        """
        Assign index for linear op.
        1. copy trace from input node and change last index according to weight
        2. mark equal for input node last index, weight first dim and bias dim.
        3. inherit input's computation, mark computation for last dim.

        Args:
            node (node)
            node_idx (int)
        """
        input_node, weight, bias = node.args
        input_node_idx_trace = self.find_idx_trace_from_node(input_node)
        weight_idx_trace = self.find_idx_trace_from_node(weight)

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        new_idx_trace[-1] = weight_idx_trace[1]
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

        self.inherit_computation(input_node, node)
        self.mark_computation(node, node_idx, [-1])
        self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])

        if bias:
            bias_idx_trace = self.find_idx_trace_from_node(bias)
            self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])

    def assign_matmul_index(self, node, node_idx):
        """
        Assign index for matmul op.
        1. copy trace from matmul_left and change last index according to matmul_right. (assert they have same length)
        2. mark equal for input matmul_left -1 index and matmul_right -2 dim.
        3. inherit matmul_left and matmul_right computation, mark computation for last dim.

        Args:
            node (node)
            node_idx (int)
        """
        matmul_left, matmul_right = node.args
        matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
        matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)

        # only same-rank (batched) matmul is supported here
        assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace))
        new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
        new_idx_trace[-1] = matmul_right_idx_trace[-1]
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

        self.inherit_computation(matmul_left, node)
        self.inherit_computation(matmul_right, node)
        self.mark_computation(node, node_idx, [-1])
        self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])

    def assign_layernorm_index(self, node, idx):
        """
        Assign index for layernorm op.
        1. assign index as input node
        2. inherit computation and mark last 2 dims as computed.

        Args:
            node (node)
            idx (int): node index
        """
        self.assign_index_as_input(node, idx)
        self.inherit_computation(node.args[0], node)
        self.mark_computation(node, idx, [-1, -2])

    def assign_elementwise_index(self, node, idx):
        """
        Assign index for element-wise op (eg. relu sigmoid add mul).
        1. assign index as input node
        2. inherit computation from all input nodes.

        Args:
            node (node)
            idx (int): node index
        """
        self.assign_index_as_input(node, idx)
        for node_in in node.args:
            # scalar operands (int/float) carry no index trace
            if type(node_in) not in (int, float):
                self.inherit_computation(node_in, node)

    def assign_softmax_index(self, node, idx):
        """
        Assign index for softmax op.
        1. assign index as input node
        2. inherit computation and mark softmax dim as computed.

        Args:
            node (node)
            idx (int): node index
        """
        self.assign_index_as_input(node, idx)
        self.inherit_computation(node.args[0], node)
        self.mark_computation(node, idx, [node.kwargs['dim']])

    def assign_view_reshape_index(self, node, node_idx):
        """
        Assign index for view and reshape op.
        1. get origin shape and target shape by meta info.
        2. compute the real value of -1 in target shape.
        3. determine changed dim, and assign index for generated dim.
        4. log changed dim and generated dim for restore
        5. inherit computation.
        6. TODO: look into view list to see whether the view is associated with other,
           if so assign equal dim according to previous view.

        Args:
            node (node)
            node_idx (int)
        """
        # get data, turn into number
        origin_node = node.args[0]
        origin_shape = origin_node.meta['tensor_meta'].shape
        target_shape = []
        for i in range(1, len(node.args)):
            if isinstance(node.args[i], int):
                target_shape.append(node.args[i])
            else:
                # a dim given as a node: read its recorded forward output value
                target_shape.append(node.args[i].meta['fwd_out'][0])

        # compute the value of -1
        if -1 in target_shape:
            origin_product = 1
            for i in origin_shape:
                origin_product *= i
            # start at -1 so the single -1 entry flips the sign back to positive
            target_product = -1
            for i in target_shape:
                target_product *= i
            shape_idx = target_shape.index(-1)
            target_shape[shape_idx] = origin_product // target_product

        # determine changed dim; only a single merge or split of adjacent dims is supported
        len_diff = len(origin_shape) - len(target_shape)
        if len_diff == 1:
            # dim merge
            dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
            dim_to = [dim_equal.index(False)]
            dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
        elif len_diff == -1:
            # dim expand
            dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
            dim_from = [dim_equal.index(False)]
            dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
        else:
            raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented")

        # get new index
        origin_trace = self.find_idx_trace_from_node(origin_node)
        new_trace = copy.deepcopy(origin_trace)
        # pop from the back so earlier positions stay valid
        dim_from.reverse()
        for i in dim_from:
            new_trace.pop(i)
        for i in dim_to:
            new_trace.insert(i, self.add_index())
        self.idx_trace_list[node_idx]['idx'] = new_trace

        # inherit computation; if any source dim was computed, mark all generated dims
        self.inherit_computation(origin_node, node)
        compute_log = self.find_compute_trace_from_node(origin_node)
        for i in dim_from:
            if origin_trace[i] in compute_log:
                for j in dim_to:
                    self.mark_computation(node, node_idx, [j])
                break

        # log view, not used now
        view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
                     "dim_from": dim_from,
                     "idx_to": [new_trace[i] for i in dim_to],
                     "dim_to": dim_to}
        self.idx_view_list.append(view_dict)

    def trace_node_idx(self):
        """Walk the graph in execution order and assign an index trace to every node.

        Dispatches on ``node.op`` and on substrings of ``node.name``;
        raises ``NotImplementedError`` for any op it does not recognise.
        """
        for idx, node in enumerate(self.nodes_list):
            if node.op == 'placeholder':
                self.assign_all_index(node, idx)
            elif node.op == 'call_method':
                if 'transpose' in node.name:
                    self.assign_transpose_index(node, idx)
                elif 'permute' in node.name:
                    self.assign_permute_index(node, idx)
                elif 'view' in node.name or 'reshape' in node.name:
                    self.assign_view_reshape_index(node, idx)
                else:
                    raise NotImplementedError(node.name, "method not implemented yet!")
            elif node.op == 'call_function':
                if 'linear' in node.name:
                    self.assign_linear_index(node, idx)
                elif 'matmul' in node.name:
                    self.assign_matmul_index(node, idx)
                elif 'softmax' in node.name:
                    self.assign_softmax_index(node, idx)
                elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
                    self.assign_elementwise_index(node, idx)
                elif 'getattr' in node.name:
                    continue # get attr like shape
                elif 'getitem' in node.name:
                    continue # get item in list
                else:
                    raise NotImplementedError(node.name, "function not implemented yet!")
            elif node.op == 'call_module':
                if any(n in node.name for n in ['layernorm', 'norm']):
                    self.assign_layernorm_index(node, idx)
                else:
                    raise NotImplementedError(node.name, "module not implemented yet!")
            elif node.op == 'get_attr':
                self.assign_all_index(node, idx) # get param
            elif node.op == 'output':
                continue
            else:
                raise NotImplementedError(node.op, "op not implemented yet!")

oahzxl's avatar
oahzxl committed
393

oahzxl's avatar
oahzxl committed
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
def _get_meta_node_size(x):
    x = x.meta['tensor_meta']
    x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
    return x


def _get_output_node_size(n):
    """Sum the activation size of a node's forward outputs.

    Only ``torch.Tensor`` outputs carrying a profiler ``uuid`` are counted.
    """
    tracked = {
        out.uuid: out
        for out in n.meta["fwd_out"]
        if isinstance(out, torch.Tensor) and hasattr(out, 'uuid')
    }
    return activation_size(tracked)


def _get_delete_node_size(user, user_to_last_uses):
    """Total output size of the nodes whose last use is ``user``.

    Graph boundary nodes (placeholder/output) never free anything.
    """
    if user.op in ('placeholder', 'output'):
        return 0
    doomed = user_to_last_uses.get(user, [])
    if not doomed:
        return 0
    return sum(_get_output_node_size(n) for n in doomed)


def _get_last_usr(nodes):
    node_to_last_use: Dict[Node, Node] = {}
    user_to_last_uses: Dict[Node, List[Node]] = {}

    def register_last_uses(n: Node, user: Node):
        if n not in node_to_last_use:
            node_to_last_use[n] = user
            user_to_last_uses.setdefault(user, []).append(n)

    for node in reversed(nodes):
        map_arg(node.args, lambda n: register_last_uses(n, node))
        map_arg(node.kwargs, lambda n: register_last_uses(n, node))
    return user_to_last_uses


oahzxl's avatar
oahzxl committed
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
def _delete_free_var_from_last_use(user_to_last_uses):
    for key, value in user_to_last_uses.items():
        for n in value:
            if n.op == 'placeholder':
                user_to_last_uses[key].remove(n)


def _get_contiguous_memory(node, not_contiguous_list, delete=False):
    """Estimate extra memory (bytes) caused by non-contiguous tensors at ``node``.

    Also maintains ``not_contiguous_list`` in place: transpose/permute results
    (and anything consuming them) are tracked as non-contiguous; modules make
    their inputs contiguous again, so those entries are dropped when
    ``delete`` is True.

    Returns:
        int: extra bytes needed (only non-zero for matmul on a tracked input).
    """
    extra_mem = 0
    not_contiguous_ops = ['transpose', 'permute']

    if node.op == 'call_function' and 'matmul' in node.name:
        for arg in node.args:
            if arg in not_contiguous_list:
                # matmul won't change origin tensor, but create a tmp copy
                extra_mem += _get_output_node_size(arg)
    elif node.op == 'call_module':
        for arg in node.args:
            # module will just make origin tensor to contiguous
            if arg in not_contiguous_list and delete:
                not_contiguous_list.remove(arg)
    elif node.op == 'call_method' and any(op in node.name for op in not_contiguous_ops):
        if node not in not_contiguous_list:
            not_contiguous_list.append(node)
    elif any(tracked in node.args for tracked in not_contiguous_list):
        # non-contiguity propagates through other consumers
        if node not in not_contiguous_list:
            not_contiguous_list.append(node)

    return extra_mem


oahzxl's avatar
oahzxl committed
462
def _estimate_inference_mem(gm: torch.fx.GraphModule):
    """
    Estimate activation memory of ``gm`` during inference without chunking,
    by replaying nodes in execution order, adding each node's output size
    and releasing tensors at their last use.

    NOTE(review): activation memory is accumulated in MB (divided by 1024**2)
    while ``parameter_size(gm)`` is added without conversion — confirm both
    are in the same unit before trusting the returned sum.

    Returns:
        (float, float): (activation + parameter memory, parameter memory)
    """
    act_memory = 0.0
    act_memory_peak_log = []
    act_memory_after_node_log = []
    not_contiguous_list = []
    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
    _delete_free_var_from_last_use(user_to_last_uses)
    for node in gm.graph.nodes:
        # if node is placeholder, just add the size of the node
        if node.op == 'placeholder':
            act_memory += _get_meta_node_size(node) / (1024 ** 2)
            act_memory_peak_log.append(act_memory)
            act_memory_after_node_log.append(act_memory)
        # skip output
        elif node.op == 'output':
            continue
        # node is an operation, calculate tmp, output node and delete node memory
        else:
            # forward memory: contiguity temporaries + this node's output
            act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
            act_memory += _get_output_node_size(node) / (1024 ** 2)
            # record max act memory
            act_memory_peak_log.append(act_memory)
            # delete useless memory: last-use tensors and contiguity temporaries
            act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
            act_memory_after_node_log.append(act_memory)

    print("no chunk")
    _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
    _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")

    param_memory = parameter_size(gm)
    return act_memory + param_memory, param_memory
oahzxl's avatar
oahzxl committed
496
497


oahzxl's avatar
oahzxl committed
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
def _get_chunk_ratio(node, chunk_dim, chunk_size):
    shape = node.meta['tensor_meta'].shape
    chunk_ratio = float(chunk_size) / shape[chunk_dim]
    return chunk_ratio


def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
    """Chunk-scaled output size of nodes last used by ``user``.

    Only nodes positioned inside [start_node, end_node) count; their size is
    scaled by ``chunk_ratio`` since only one chunk is alive at a time.
    """
    if user.op in ('placeholder', 'output'):
        return 0
    freed = 0
    for n in user_to_last_uses.get(user, []):
        pos = _find_idx_by_name(n.name, node_list)
        if start_node <= pos < end_node:
            freed += _get_output_node_size(n) * chunk_ratio
    return freed


oahzxl's avatar
oahzxl committed
516
def _print_mem_log(log, nodes, title=None):
    """Pretty-print per-node memory values, three entries per printed line."""
    if title:
        print(title)
    for pos, (mem, node) in enumerate(zip(log, nodes)):
        print("%s:%.2f \t" % (node.name, mem), end='')
        if (pos + 1) % 3 == 0:
            print("")
    print("\n")
oahzxl's avatar
oahzxl committed
524
525
526


def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
    """
    Estimate activation memory of ``gm`` during inference when the node
    regions ``[start_nodes[i], end_nodes[i]]`` are executed chunk-by-chunk
    along ``chunk_dims[i]`` with chunk size ``chunk_sizes[i]``.

    Inside a chunk region every allocation is scaled by the chunk ratio
    (chunk_size / full dim extent); the region's full-size output buffer is
    pre-allocated when the region starts and released (at chunk scale) when
    it ends.

    NOTE(review): activation memory is tracked in MB while
    ``parameter_size(gm)`` is added unconverted — verify units match.

    Returns:
        (float, float): (activation + parameter memory, parameter memory)
    """
    act_memory = 0.0
    act_memory_peak_log = []
    act_memory_after_node_log = []
    not_contiguous_list = []
    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
    _delete_free_var_from_last_use(user_to_last_uses)
    within_chunk = False
    region_idx = 0
    chunk_ratio = 1 # use it to estimate chunk mem
    node_list = list(gm.graph.nodes)

    for idx, node in enumerate(node_list):
        # if node in chunk start nodes, change chunk ratio and add chunk_tensor
        if idx in start_nodes:
            within_chunk = True
            chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
            # pre-allocate the region's full-size result buffer
            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)

        # if node is placeholder, just add the size of the node
        if node.op == 'placeholder':
            act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
            act_memory_peak_log.append(act_memory)
        # skip output
        elif node.op == 'output':
            continue
        # node is an operation, calculate tmp, output node and delete node memory
        else:
            # forward memory
            # TODO: permute will create a tmp copy if not contiguous
            act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
            act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
            # record max act memory
            act_memory_peak_log.append(act_memory)
            # delete useless memory
            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
            if within_chunk:
                # inside a region only chunk-local tensors are freed, at chunk scale
                act_memory -= _get_chunk_delete_node_size(
                    node, user_to_last_uses, chunk_ratio, node_list,
                    start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
            else:
                act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)

        if idx in end_nodes:
            # release the last chunk's partial output and leave the region
            act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
            within_chunk = False
            chunk_ratio = 1
            region_idx += 1

        act_memory_after_node_log.append(act_memory)

    print("chunk")
    _print_mem_log(act_memory_peak_log, node_list, "peak")
    _print_mem_log(act_memory_after_node_log, node_list, "after")

    param_memory = parameter_size(gm)
    return act_memory + param_memory, param_memory
oahzxl's avatar
oahzxl committed
583
584


oahzxl's avatar
oahzxl committed
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
    new_shape = "["
    for idx, i in enumerate(shape):
        if idx == chunk_dim:
            new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
        else:
            new_shape += ":"
        new_shape += ", "
    new_shape = new_shape[:-2] + "]"
    return new_shape


def _get_first_non_single_dim(shape):
    for idx, i in enumerate(shape):
        if i == 1:
            continue
        else:
            return idx
    raise RuntimeError("can not get first non single dim for shape", shape)


def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
    """Generate the source code that opens the chunked loop.

    Allocates the full-size result buffer, iterates the input's first
    non-singleton dim in steps of ``chunk_size`` and slices out one chunk.

    Raises:
        NotImplementedError: if there is more than one chunk input.
    """
    if len(chunk_input_meta) != 1:
        raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))

    node = chunk_input_meta[0]
    node_shape = node.meta['tensor_meta'].shape
    chunk_dim = _get_first_non_single_dim(node_shape)
    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
    out_shape = str(list(chunk_output.meta['tensor_meta'].shape))

    context = "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" % (
        out_shape, node.name, node.name, chunk_size)
    context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
    context += "    chunk_tensor = %s%s\n" % (node.name, chunk_slice)
    return context


oahzxl's avatar
oahzxl committed
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
    """Generate the source code that closes the chunked loop.

    Writes the chunk result into its slice of the full buffer, rebinds the
    output name to the buffer, clears loop temporaries, and frees the chunk
    input if nothing after the region still uses it.
    """
    input_name = chunk_inputs[0].name
    output_name = chunk_outputs.name
    output_pos = _find_idx_by_name(output_name, node_list)
    output_shape = chunk_outputs.meta['tensor_meta'].shape
    chunk_dim = _get_first_non_single_dim(output_shape)
    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", output_shape)

    context = "    chunk_result%s = %s\n" % (chunk_slice, output_name)
    context += output_name + " = chunk_result;  chunk_result = None;  chunk_size = None"

    # determine if its the last use for chunk input
    if all(_find_idx_by_name(u.name, node_list) <= output_pos for u in chunk_inputs[0].users.keys()):
        context += ";  %s = None" % input_name

    context += "\n"
    return context

oahzxl's avatar
init  
oahzxl committed
642
643
644
645
646
647
648
649
650
651
652
653
654
655

def _find_input_and_output_nodes(nodes: List[Node]):
    """
    Find the input and output node names which are not found in the given list of nodes.
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list
    # we treat that input node as the input of the checkpoint function
    for node in nodes:
        for input_node in node._input_nodes.keys():
            node_repr = repr(input_node)
            if input_node not in nodes and node_repr not in input_nodes:
oahzxl's avatar
oahzxl committed
656
                input_nodes.append(input_node)
oahzxl's avatar
init  
oahzxl committed
657
658
659
660
661
662
663

    # if a node has a user node which is not in the node list
    # we treat that user node as the node receiving the current node output
    for node in nodes:
        for output_node in node.users.keys():
            node_repr = repr(node)
            if output_node not in nodes and node_repr not in output_nodes:
oahzxl's avatar
oahzxl committed
664
                output_nodes.append(output_node)
oahzxl's avatar
init  
oahzxl committed
665
666
667
668

    return input_nodes, output_nodes


oahzxl's avatar
oahzxl committed
669
670
671
672
673
def _find_idx_by_name(name, nodes_list):
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
    raise RuntimeError("name %s not found in node list" % name)
oahzxl's avatar
init  
oahzxl committed
674
675


oahzxl's avatar
oahzxl committed
676
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
    """Emit forward code with chunk regions rewritten as explicit loops.

    Nodes whose indices fall inside a chunk region are emitted indented one
    level inside a generated ``for`` loop (opened by ``_gen_loop_start`` and
    closed by ``_gen_loop_end``) so the region runs slice-by-slice along the
    chunk dimension instead of on the full tensor at once.

    Args:
        body: list of generated forward-code lines; appended to in place
        ckpt_func: checkpoint functions code (kept for interface parity with
            the activation-checkpoint codegen; not used in this path)
        nodes: graph.nodes of the graph being emitted
        emit_node_func: function to emit a single node into ``body``
        delete_unused_value_func: function to free values after their last
            use; its third argument is a list of names to keep alive
        meta_nodes: meta-graph nodes, assumed index-aligned with ``nodes``
        meta_graph: traced meta graph used for memory estimation
    """

    # find the offload regions
    # NOTE(review): the chunk region is hard-coded to node indices 58-62 —
    # presumably a placeholder until automatic region search lands; confirm.
    chunk_regions = [(58, 62)]
    chunk_starts = [item[0] for item in chunk_regions]
    chunk_ends = [item[1] for item in chunk_regions]
    chunk_inputs = []
    chunk_outputs = []
    within_chunk_region = False

    node_list = list(nodes)
    # NOTE(review): the two estimators and the index tracer appear to run for
    # their side effects only; none of their results are consumed below.
    _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
    _estimate_inference_mem(meta_graph)
    node_index_tracer = NodeIndexTracer(meta_graph)
    node_index_tracer.trace_node_idx()

    # find the input and output var names for each offload region
    for idx, (start, end) in enumerate(chunk_regions):
        offload_node_list = node_list[start:end + 1]
        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
        chunk_inputs.append(inputs)
        chunk_outputs.append(outputs)
    chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
    # NOTE(review): chunk_outputs_idx is computed but never read below.
    chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
    # flat list of every chunk-input name; passed as the keep-list so these
    # variables are not freed before the loop epilogue decides their fate
    chunk_inputs_names = []
    for i in chunk_inputs:
        for j in i:
            chunk_inputs_names.append(j.name)
    
    # this flag is to prevent repeated insert of save tensors
    # hooks definition in ckpt_func
    node_idx = 0
    region_idx = 0
    while node_idx < len(node_list):
        node = node_list[node_idx]

        if node_idx in chunk_starts:
            within_chunk_region = True
                
            # add for loop
            chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
            body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))

        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
            if node_idx in chunk_starts:
                body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)')
            # indent the emitted statement so it sits inside the generated loop
            body[-1] = '    ' + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)

        else:
            emit_node_func(node, body)
            # NOTE(review): ``chunk_inputs`` holds lists of Node objects, so an
            # int ``node_idx`` is never contained in it and this test is always
            # true; chunk inputs are actually protected via the keep-list above.
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_inputs_names)

        if node_idx in chunk_ends:
            # close the loop: write the slice back, free loop temporaries
            body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
            within_chunk_region = False
            region_idx += 1

        node_idx += 1
oahzxl's avatar
init  
oahzxl committed
749
750
751
752


if CODEGEN_AVAILABLE:

    class ChunkCodeGen(CodeGen):
        """Code generator that splices chunk-loop emission into fx codegen.

        ``_gen_python_code`` largely mirrors
        ``torch.fx.graph.CodeGen._gen_python_code``; the chunk-specific
        behaviour is injected where the body lines are emitted (see
        ``emit_code_with_chunk``).
        """

        def __init__(self, meta_graph):
            super().__init__()
            # traced meta graph carrying tensor_meta used for shape lookups
            self.meta_graph = meta_graph
            # meta nodes, assumed index-aligned with the nodes being emitted
            self.meta_node = list(meta_graph.graph.nodes)

        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
            """Generate python source for ``nodes``.

            Adapted from ``torch.fx.graph.CodeGen._gen_python_code`` with the
            body emitted through ``emit_code_with_chunk`` so chunk regions
            become explicit loops.
            """
            free_vars: List[str] = []
            body: List[str] = []
            globals_: Dict[str, Any] = {}
            wrapped_fns: Dict[str, None] = {}

            # Wrap string in list to pass by reference
            maybe_return_annotation: List[str] = ['']

            def add_global(name_hint: str, obj: Any):
                """Add an obj to be tracked as a global.

                We call this for names that reference objects external to the
                Graph, like functions or types.

                Returns: the global name that should be used to reference 'obj' in generated source.
                """
                if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
                    # HACK: workaround for how torch custom ops are registered. We
                    # can't import them like normal modules so they must retain their
                    # fully qualified name.
                    return _get_qualified_name(obj)

                # normalize the name hint to get a proper identifier
                global_name = namespace.create_name(name_hint, obj)

                if global_name in globals_:
                    assert globals_[global_name] is obj
                    return global_name
                globals_[global_name] = obj
                return global_name

            # set _custom_builtins here so that we needn't import colossalai in forward
            _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

            # Pre-fill the globals table with registered builtins.
            for name, (_, obj) in _custom_builtins.items():
                add_global(name, obj)

            def type_repr(o: Any):
                # Render a type annotation, registering referenced types as globals.
                if o == ():
                    # Empty tuple is used for empty tuple type annotation Tuple[()]
                    return '()'

                typename = _type_repr(o)

                if hasattr(o, '__origin__'):
                    # This is a generic type, e.g. typing.List[torch.Tensor]
                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                    origin_typename = add_global(_type_repr(origin_type), origin_type)

                    if hasattr(o, '__args__'):
                        # Assign global names for each of the inner type variables.
                        args = [type_repr(arg) for arg in o.__args__]

                        if len(args) == 0:
                            # Bare type, such as `typing.Tuple` with no subscript
                            # This code-path used in Python < 3.9
                            return origin_typename

                        return f'{origin_typename}[{",".join(args)}]'
                    else:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python 3.9+
                        return origin_typename

                # Common case: this is a regular module name like 'foo.bar.baz'
                return add_global(typename, o)

            def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
                # Render a call's positional and keyword arguments as source text.

                def _get_repr(arg):
                    # Handle NamedTuples (if it has `_fields`) via add_global.
                    if isinstance(arg, tuple) and hasattr(arg, '_fields'):
                        qualified_name = _get_qualified_name(type(arg))
                        global_name = add_global(qualified_name, type(arg))
                        return f"{global_name}{repr(tuple(arg))}"
                    return repr(arg)

                args_s = ', '.join(_get_repr(a) for a in args)
                kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
                if args_s and kwargs_s:
                    return f'{args_s}, {kwargs_s}'
                return args_s or kwargs_s

            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use: Dict[Node, Node] = {}
            user_to_last_uses: Dict[Node, List[Node]] = {}

            def register_last_uses(n: Node, user: Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))
            
            _delete_free_var_from_last_use(user_to_last_uses)
            
            # NOTE: we add a variable to distinguish body and ckpt_func
            def delete_unused_values(user: Node, body, to_keep=[]):
                """
                Delete values after their last use. This ensures that values that are
                not used in the remainder of the code are freed and the memory usage
                of the code is optimal.
                """
                # NOTE(review): mutable default ``to_keep=[]`` is only read,
                # never mutated, so the shared default is harmless here.
                if user.op == 'placeholder':
                    return
                if user.op == 'output':
                    body.append('\n')
                    return
                nodes_to_delete = user_to_last_uses.get(user, [])
                # never free the names the chunk loop still needs
                nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
                if len(nodes_to_delete):
                    to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                    body.append(f';  {to_delete_str}\n')
                else:
                    body.append('\n')

            # NOTE: we add a variable to distinguish body and ckpt_func
            def emit_node(node: Node, body):
                # Append the source statement for one fx node to ``body``.
                maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
                if node.op == 'placeholder':
                    assert isinstance(node.target, str)
                    maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
                    free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                    raw_name = node.target.replace('*', '')
                    if raw_name != repr(node):
                        body.append(f'{repr(node)} = {raw_name}\n')
                    return
                elif node.op == 'call_method':
                    assert isinstance(node.target, str)
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
                        f'({_format_args(node.args[1:], node.kwargs)})')
                    return
                elif node.op == 'call_function':
                    assert callable(node.target)
                    # pretty print operators
                    if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                        assert isinstance(node.args, tuple)
                        body.append(f'{repr(node)}{maybe_type_annotation} = '
                                    f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
                        return

                    # pretty print inplace operators; required for jit.script to work properly
                    # not currently supported in normal FX graphs, but generated by torchdynamo
                    if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
                        body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  '
                                    f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
                        return

                    qualified_name = _get_qualified_name(node.target)
                    global_name = add_global(qualified_name, node.target)
                    # special case for getattr: node.args could be 2-argument or 3-argument
                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                    if global_name == 'getattr' and \
                    isinstance(node.args, tuple) and \
                    isinstance(node.args[1], str) and \
                    node.args[1].isidentifier() and \
                    len(node.args) == 2:
                        body.append(
                            f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                        return
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                    if node.meta.get('is_wrapped', False):
                        wrapped_fns.setdefault(global_name)
                    return
                elif node.op == 'call_module':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = '
                                f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
                    return
                elif node.op == 'get_attr':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
                    return
                elif node.op == 'output':
                    if node.type is not None:
                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                    body.append(self.generate_output(node.args[0]))
                    return
                raise NotImplementedError(f'node: {node.op} {node.target}')

            # Modified for activation checkpointing
            ckpt_func = []

            # if any node has a list of labels for activation_checkpoint, we
            # will use nested type of activation checkpoint codegen
            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph)

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body
                # have been emitted. To continue to have valid Python code, emit a
                # single pass statement
                body.append('pass\n')

            if len(wrapped_fns) > 0:
                wrap_name = add_global('wrap', torch.fx.wrap)
                wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
            else:
                wrap_stmts = ''

            if self._body_transformer:
                body = self._body_transformer(body)

            for name, value in self.additional_globals():
                add_global(name, value)

            # as we need colossalai.utils.checkpoint, we need to import colossalai
            # in forward function
            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
            prologue = ''.join(ckpt_func) + prologue
            # NOTE(review): self-assignment below is a no-op carried over from
            # the activation-checkpoint codegen.
            prologue = prologue

            code = ''.join(body)
            code = '\n'.join('    ' + line for line in code.split('\n'))
            fn_code = f"""
{wrap_stmts}

{prologue}
{code}"""
            # debugging aid: dump the generated forward source
            print(fn_code)
            return PythonCode(fn_code, globals_)