# chunk_codegen.py
import colossalai
import torch
oahzxl's avatar
oahzxl committed
3
import copy
oahzxl's avatar
init  
oahzxl committed
4
5
from typing import List, Callable, Any, Tuple, Dict, Iterable

oahzxl's avatar
oahzxl committed
6
7
8
9
10
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
CODEGEN_AVAILABLE = True
__all__ = ['ChunkCodeGen']
oahzxl's avatar
init  
oahzxl committed
11
12


oahzxl's avatar
oahzxl committed
13
14
15
16
17
18
19
def _delete_free_var_from_last_use(user_to_last_uses):
    for key, value in user_to_last_uses.items():
        for n in value:
            if n.op == 'placeholder':
                user_to_last_uses[key].remove(n)


oahzxl's avatar
oahzxl committed
20
21
22
23
24
25
class NodeIndexTracer(object):
    """Trace symbolic dimension indices for every node in an FX graph.

    Each node gets a trace entry ``{'idx': [...], 'compute': [...]}`` where
    ``idx`` holds one integer id per tensor dimension and ``compute`` lists
    the ids of dimensions some computation has mixed over. Index ids that are
    recorded as equal (e.g. matmul inner dims) are merged after tracing.
    """

    def __init__(self, gm) -> None:
        self.gm = gm
        self.nodes_list = list(gm.graph.nodes)
        # one trace dict per node, aligned with nodes_list
        self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
        # pairs of index ids known to be equal; merged in _merge_equal_idx
        self.idx_trace_equal = []
        # log of view/reshape dim changes (kept for future use)
        self.idx_view_list = []
        # running counter of generated index ids
        self.idx_count = -1

    def _add_index(self):
        """
        Update the count and return it. To record the idx number.

        Returns:
            idx_count: int
        """
        self.idx_count += 1
        return self.idx_count

    def _inherit_computation(self, node_from, node_to):
        """
        Inherit computed dim from node_from to node_to.
        If a dim in node_from is marked as computed and exists in node_to,
        still mark it as computed in node_to.

        Args:
            node_from (node): node to be inherited
            node_to (node): new node to inherit
        """
        _, compute_from = self._find_trace_from_node(node_from)
        idx_to, compute_to = self._find_trace_from_node(node_to)
        for i in compute_from:
            if i in idx_to and i not in compute_to:
                compute_to.append(i)

    def _mark_idx_equal(self, idx1, idx2):
        """
        Mark 2 index to be equal.

        Args:
            idx1 (int): index count.
            idx2 (int): index count.
        """
        self.idx_trace_equal.append((idx1, idx2))

    def _mark_computation(self, node, idx, dim):
        """
        Mark some dims of node as computed.

        Args:
            node (node)
            idx (int): node index
            dim (list or int): dims to be marked as computed
        """
        input_node_idx_trace = self._find_idx_trace_from_node(node)
        if isinstance(dim, int):
            dim = [dim]
        for d in dim:
            cur_idx = input_node_idx_trace[d]
            if cur_idx not in self.idx_trace_list[idx]['compute']:
                self.idx_trace_list[idx]['compute'].append(cur_idx)

    def _find_trace_from_node(self, node):
        """
        Find node idx and compute trace by the node.

        Args:
            node (node)
        Returns:
            idx (list): idx of the node
            compute (list): computed idx of the node.
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        node_dict = self.idx_trace_list[node_idx]
        return node_dict['idx'], node_dict['compute']

    def _find_idx_trace_from_node(self, node):
        """
        Find node idx trace by the node.

        Args:
            node (node)
        Returns:
            idx (list): idx of the node
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        return self.idx_trace_list[node_idx]['idx']

    def _find_compute_trace_from_node(self, node):
        """
        Find node compute trace by the node.

        Args:
            node (node)
        Returns:
            compute (list): computed idx of the node.
        """
        node_idx = _find_idx_by_name(node.name, self.nodes_list)
        return self.idx_trace_list[node_idx]['compute']

    def _assign_index_as_input(self, node, node_idx):
        """
        Assign node's trace as its input node.

        Args:
            node (node)
            node_idx (int)
        """
        input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
        input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

    def _assign_all_index(self, node, node_idx):
        """
        Add new index for all node's dims.

        Args:
            node (node)
            node_idx (int)
        """
        shape = node.meta['tensor_meta'].shape
        new_trace = []
        for _ in shape:
            new_trace.append(self._add_index())
        self.idx_trace_list[node_idx]['idx'] = new_trace

    def _assign_transpose_index(self, node, node_idx):
        """
        Assign index for transpose op.
        1. swap input's dim according to transpose args
        2. inherit input's computation

        Args:
            node (node)
            node_idx (int)
        """
        transpose_dim = node.args[1:]
        input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        new_idx_trace[transpose_dim[0]] = input_node_idx_trace[transpose_dim[1]]
        new_idx_trace[transpose_dim[1]] = input_node_idx_trace[transpose_dim[0]]

        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
        self._inherit_computation(node.args[0], node)

    def _assign_permute_index(self, node, node_idx):
        """
        Assign index for permute op.
        1. swap input's dim according to permute args
        2. inherit input's computation

        Args:
            node (node)
            node_idx (int)
        """
        permute_dim = node.args[1:]
        input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        for idx, d in enumerate(permute_dim):
            new_idx_trace[idx] = input_node_idx_trace[d]

        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
        self._inherit_computation(node.args[0], node)

    def _assign_linear_index(self, node, node_idx):
        """
        Assign index for linear op.
        1. copy trace from input node and change last index according to weight
        2. mark equal for input node last index, weight first dim and bias dim.
        3. inherit input's computation, mark computation for last dim.

        Args:
            node (node)
            node_idx (int)
        """
        input_node, weight, bias = node.args
        input_node_idx_trace = self._find_idx_trace_from_node(input_node)
        weight_idx_trace = self._find_idx_trace_from_node(weight)

        new_idx_trace = copy.deepcopy(input_node_idx_trace)
        new_idx_trace[-1] = weight_idx_trace[1]
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

        self._inherit_computation(input_node, node)
        self._mark_computation(node, node_idx, [-1])
        self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])

        if bias:
            bias_idx_trace = self._find_idx_trace_from_node(bias)
            self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])

    def _assign_matmul_index(self, node, node_idx):
        """
        Assign index for matmul op.
        1. copy trace from matmul_left and change last index according to matmul_right.
           (assert they have same length)
        2. mark equal for input matmul_left -1 index and matmul_right -2 dim.
        3. inherit matmul_left and matmul_right computation, mark computation for last dim.

        Args:
            node (node)
            node_idx (int)
        """
        matmul_left, matmul_right = node.args
        matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left)
        matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right)

        assert len(matmul_left_idx_trace) == len(matmul_right_idx_trace)
        new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
        new_idx_trace[-1] = matmul_right_idx_trace[-1]
        self.idx_trace_list[node_idx]['idx'] = new_idx_trace

        self._inherit_computation(matmul_left, node)
        self._inherit_computation(matmul_right, node)
        self._mark_computation(node, node_idx, [-1])
        self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])

    def _assign_layernorm_index(self, node, idx):
        """
        Assign index for layernorm op.
        1. assign index as input node
        2. inherit computation and mark last 2 dims as computed.

        Args:
            node (node)
            idx (int)
        """
        self._assign_index_as_input(node, idx)
        self._inherit_computation(node.args[0], node)
        self._mark_computation(node, idx, [-1, -2])

    def _assign_elementwise_index(self, node, idx):
        """
        Assign index for element-wise op (eg. relu sigmoid add mul).
        1. assign index as input node
        2. inherit computation from all input nodes.

        Args:
            node (node)
            idx (int)
        """
        self._assign_index_as_input(node, idx)
        for node_in in node.args:
            # scalar constants carry no trace; only inherit from real nodes
            if not isinstance(node_in, (int, float)):
                self._inherit_computation(node_in, node)

    def _assign_softmax_index(self, node, idx):
        """
        Assign index for softmax op.
        1. assign index as input node
        2. inherit computation and mark softmax dim as computed.

        Args:
            node (node)
            idx (int)
        """
        self._assign_index_as_input(node, idx)
        self._inherit_computation(node.args[0], node)
        self._mark_computation(node, idx, [node.kwargs['dim']])

    def _assign_view_reshape_index(self, node, node_idx):
        """
        Assign index for view and reshape op.
        1. get origin shape and target shape by meta info.
        2. compute the real value of -1 in target shape.
        3. determine changed dim, and assign index for generated dim.
        4. log changed dim and generated dim for restore
        5. inherit computation.
        6. TODO: look into view list to see whether the view is associated with other,
           if so assign equal dim according to previous view.

        Args:
            node (node)
            node_idx (int)
        """
        # get data, turn into number
        origin_node = node.args[0]
        origin_shape = origin_node.meta['tensor_meta'].shape
        target_shape = []
        for i in range(1, len(node.args)):
            if isinstance(node.args[i], int):
                target_shape.append(node.args[i])
            else:
                target_shape.append(node.args[i].meta['fwd_out'][0])

        # compute the value of -1
        if -1 in target_shape:
            origin_product = 1
            for i in origin_shape:
                origin_product *= i
            # start at -1 so the placeholder cancels itself in the product
            target_product = -1
            for i in target_shape:
                target_product *= i
            shape_idx = target_shape.index(-1)
            target_shape[shape_idx] = origin_product // target_product

        # determine changed dim; only single merge/expand of adjacent dims supported
        len_diff = len(origin_shape) - len(target_shape)
        if len_diff == 1:
            # dim merge
            dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
            dim_to = [dim_equal.index(False)]
            dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
        elif len_diff == -1:
            # dim expand
            dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
            dim_from = [dim_equal.index(False)]
            dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
        else:
            raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented")

        # get new index
        origin_trace = self._find_idx_trace_from_node(origin_node)
        new_trace = copy.deepcopy(origin_trace)
        dim_from.reverse()
        for i in dim_from:
            new_trace.pop(i)
        for i in dim_to:
            new_trace.insert(i, self._add_index())
        self.idx_trace_list[node_idx]['idx'] = new_trace

        # inherit computation
        self._inherit_computation(origin_node, node)
        compute_log = self._find_compute_trace_from_node(origin_node)
        for i in dim_from:
            if origin_trace[i] in compute_log:
                for j in dim_to:
                    self._mark_computation(node, node_idx, [j])
                break

        # log view, not used now
        view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
                     "dim_from": dim_from,
                     "idx_to": [new_trace[i] for i in dim_to],
                     "dim_to": dim_to}
        self.idx_view_list.append(view_dict)

    def _merge_equal_idx(self):
        """Collapse every recorded equal-index pair into the smaller id."""
        idx_equal = copy.deepcopy(self.idx_trace_equal)
        idx_equal.reverse()
        for idx in idx_equal:
            merge_to = min(idx)
            merge_from = max(idx)
            for trace in self.idx_trace_list:
                if merge_from in trace['idx']:
                    trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']]

    def trace_node_idx(self):
        """Walk the graph once, assigning index traces per op kind, then merge equal ids.

        Raises:
            NotImplementedError: for any op/method/function/module not handled yet.
        """
        for idx, node in enumerate(self.nodes_list):
            if node.op == 'placeholder':
                self._assign_all_index(node, idx)
            elif node.op == 'call_method':
                if 'transpose' in node.name:
                    self._assign_transpose_index(node, idx)
                elif 'permute' in node.name:
                    self._assign_permute_index(node, idx)
                elif 'view' in node.name or 'reshape' in node.name:
                    self._assign_view_reshape_index(node, idx)
                else:
                    raise NotImplementedError(node.name, "method not implemented yet!")
            elif node.op == 'call_function':
                if 'linear' in node.name:
                    self._assign_linear_index(node, idx)
                elif 'matmul' in node.name:
                    self._assign_matmul_index(node, idx)
                elif 'softmax' in node.name:
                    self._assign_softmax_index(node, idx)
                elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
                    self._assign_elementwise_index(node, idx)
                elif 'getattr' in node.name:
                    continue  # get attr like shape
                elif 'getitem' in node.name:
                    continue  # get item in list
                else:
                    raise NotImplementedError(node.name, "function not implemented yet!")
            elif node.op == 'call_module':
                if any(n in node.name for n in ['layernorm', 'norm']):
                    self._assign_layernorm_index(node, idx)
                else:
                    raise NotImplementedError(node.name, "module not implemented yet!")
            elif node.op == 'get_attr':
                self._assign_all_index(node, idx)    # get param
            elif node.op == 'output':
                continue
            else:
                raise NotImplementedError(node.op, "op not implemented yet!")
        self._merge_equal_idx()
oahzxl's avatar
oahzxl committed
410

oahzxl's avatar
oahzxl committed
411

oahzxl's avatar
oahzxl committed
412
413
414
class MemoryEstimator(object):
    """Estimate per-node activation memory of an FX graph at inference time,
    optionally simulating chunked execution of selected graph regions.

    Sizes are accumulated in MB; the chunked regions are scaled by the ratio
    ``chunk_size / full_dim_size`` of the chunked dimension.
    """

    def __init__(self) -> None:
        pass

    def _get_meta_node_size(self, x):
        """Return the byte size of a node's meta tensor: numel * element size."""
        x = x.meta['tensor_meta']
        x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
        return x

    def _get_output_node(self, n):
        """Return (output byte size, [node name if it produces output]).

        Only tensors carrying a ``uuid`` in ``meta['fwd_out']`` are counted.
        """
        fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
        out_size = activation_size(fwd_out)
        out_node = [n.name] if out_size > 0 else []
        return out_size, out_node

    def _get_output_node_size(self, n):
        """Byte size of n's output activations."""
        return self._get_output_node(n)[0]

    def _add_active_node(self, n, active_list):
        """Record n as active (by name) if it produces output and is not listed yet."""
        new_active = self._get_output_node(n)[1]
        for i in new_active:
            if i not in active_list:
                active_list.append(i)

    def _get_delete_node(self, user, user_to_last_uses):
        """Return (total freed bytes, names freed) once `user` has executed.

        Placeholders in the last-use list are named (so they leave the active
        set) even though their size is not counted as freed.
        """
        delete_size = 0
        delete_node = []
        if user.op not in ('placeholder', 'output'):
            nodes_to_delete = user_to_last_uses.get(user, [])
            if nodes_to_delete:
                out_node = [self._get_output_node(i) for i in nodes_to_delete]
                delete_size = sum(size for size, _ in out_node)
                for candidate, (size, names) in zip(nodes_to_delete, out_node):
                    if size > 0:
                        delete_node.append(names[0])
                    elif candidate.op == 'placeholder':
                        delete_node.append(candidate.name)
        return delete_size, delete_node

    def _get_delete_node_size(self, user, user_to_last_uses):
        """Bytes freed after `user` executes."""
        return self._get_delete_node(user, user_to_last_uses)[0]

    def _remove_deactive_node(self, user, user_to_last_uses, active_list):
        """Drop from active_list every node freed after `user` executes."""
        delete_node = self._get_delete_node(user, user_to_last_uses)[1]
        for i in delete_node:
            active_list.remove(i)

    def _get_last_usr(self, nodes):
        """Map each node to the list of nodes whose LAST use is that node.

        Walks the graph in reverse so the first user seen is the last user.
        """
        node_to_last_use: Dict[Node, Node] = {}
        user_to_last_uses: Dict[Node, List[Node]] = {}

        def register_last_uses(n: Node, user: Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
        return user_to_last_uses

    def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
        """Estimate extra bytes from consuming non-contiguous tensors.

        Tracks nodes made non-contiguous by transpose/permute; matmul/reshape
        on such tensors creates a temporary contiguous copy (counted), while a
        module call makes them contiguous in place (removed when `delete`).
        """
        mem = 0
        not_contiguous_ops = ['transpose', 'permute']

        if node.op == 'call_function' and any(n in node.name for n in ['matmul', 'reshape']):
            for n in node.args:
                if n in not_contiguous_list:
                    # matmul won't change origin tensor, but create a tmp copy
                    mem += self._get_output_node_size(n)
        elif node.op == 'call_module':
            for n in node.args:
                if n in not_contiguous_list:
                    # module will just make origin tensor to contiguous
                    if delete:
                        not_contiguous_list.remove(n)
        elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
            if node not in not_contiguous_list:
                not_contiguous_list.append(node)
        elif any(i in node.args for i in not_contiguous_list):
            # consuming a non-contiguous tensor propagates non-contiguity
            if node not in not_contiguous_list:
                not_contiguous_list.append(node)

        return mem

    def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
        """Fraction of the chunked dimension held by one chunk."""
        shape = node.meta['tensor_meta'].shape
        chunk_ratio = float(chunk_size) / shape[chunk_dim]
        return chunk_ratio

    def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
        """Bytes freed after `user`, scaled by chunk_ratio, counting only nodes
        inside the chunk region [start_node, end_node)."""
        if user.op in ('placeholder', 'output'):
            return 0
        nodes_to_delete = user_to_last_uses.get(user, [])
        delete_size = 0
        for n in nodes_to_delete:
            node_idx = _find_idx_by_name(n.name, node_list)
            if start_node <= node_idx < end_node:
                delete_size += self._get_output_node_size(n) * chunk_ratio
        return delete_size

    def _print_mem_log(self, log, nodes, title=None):
        """Pretty-print one memory value per node, three per line."""
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            print("%s:%.2f \t" % (n.name, l), end='')
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=None, end_nodes=None, chunk_dims=None, chunk_sizes=None):
        """Simulate inference over `gm` and log activation memory per node.

        Args:
            gm: traced graph module with profiler meta on each node.
            start_nodes, end_nodes, chunk_dims, chunk_sizes: parallel lists
                describing chunked regions; chunking is enabled only when
                all four are provided.

        Returns:
            (peak-memory log, after-node memory log, active-node-name log),
            one entry per executed node, memory in MB.
        """
        act_memory = 0.0
        act_memory_peak_log = []
        act_memory_after_node_log = []
        active_node_list = []
        active_node_list_log = []
        not_contiguous_list = []
        node_list = list(gm.graph.nodes)
        user_to_last_uses = self._get_last_usr(node_list)
        # free vars (placeholders) are never freed, so use a stripped copy
        # when estimating deletions
        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
        _delete_free_var_from_last_use(user_to_last_uses_no_free_var)

        use_chunk = all(i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes])
        chunk_within = False
        chunk_region_idx = 0
        chunk_ratio = 1    # use it to estimate chunk mem

        for idx, node in enumerate(node_list):
            # if node in chunk start nodes, change chunk ratio and add chunk_tensor
            if use_chunk and idx in start_nodes:
                chunk_within = True
                chunk_ratio = self._get_chunk_ratio(node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx])
                # the chunk region's full output tensor is allocated up front
                act_memory += self._get_output_node_size(node_list[end_nodes[chunk_region_idx]]) / (1024 ** 2)

            # if node is placeholder, just add the size of the node
            if node.op == 'placeholder':
                act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
                act_memory_peak_log.append(act_memory)
                active_node_list.append(node.name)
            # skip output
            elif node.op == 'output':
                continue
            # node is an operation, calculate tmp, output node and delete node memory
            else:
                # forward memory
                act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
                act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
                # record max act memory
                act_memory_peak_log.append(act_memory)
                # delete useless memory
                act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
                if chunk_within:
                    act_memory -= self._get_chunk_delete_node_size(
                        node, user_to_last_uses_no_free_var, chunk_ratio, node_list,
                        start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2)
                else:
                    act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var) / (1024 ** 2)

            # log active node
            self._add_active_node(node, active_node_list)
            self._remove_deactive_node(node, user_to_last_uses, active_node_list)

            # if node in chunk end nodes, restore chunk settings
            if use_chunk and idx in end_nodes:
                act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
                chunk_within = False
                chunk_ratio = 1
                chunk_region_idx += 1

            act_memory_after_node_log.append(act_memory)
            active_node_list_log.append(copy.deepcopy(active_node_list))

        print("with chunk" if use_chunk else "without chunk")
        self._print_mem_log(act_memory_peak_log, node_list, "peak")
        self._print_mem_log(act_memory_after_node_log, node_list, "after")

        # param_memory = parameter_size(gm)
        # all_memory = act_memory + param_memory
        return act_memory_peak_log, act_memory_after_node_log, active_node_list_log


class ChunkRegionSearch(object):
    """Search a traced graph for node regions that are candidates for chunking.

    Combines a memory estimate (to locate the activation-memory peak) with
    per-dimension index traces (to find dimensions that can be split without
    breaking the computation).
    """
    def __init__(self, gm) -> None:
        # gm: the traced GraphModule to analyse
        self.gm = gm
        self.node_list = list(gm.graph.nodes)
        self.memory_estimator = MemoryEstimator()
        self.index_tracer = NodeIndexTracer(gm)
        # populate idx_trace_list up front; the search methods read it
        self.index_tracer.trace_node_idx()

    def _find_peak_node(self, mem_peak):
        """Return the position of the peak-memory step, as a one-element list.

        NOTE(review): only the FIRST occurrence of the maximum is returned
        even though the plural name and list return suggest multiple peaks
        were intended — confirm.
        """
        max_value = max(mem_peak)
        max_idx = [mem_peak.index(max_value)]
        return max_idx
    
    def _get_free_var(self):
        """Return the indices of all placeholder (free-variable) nodes."""
        free_var_idx = []
        for idx, n in enumerate(self.node_list):
            if n.op == 'placeholder':
                free_var_idx.append(idx)
        return free_var_idx
    
    def _get_min_free_var(self, active_node_list, free_vars):
        """Return the smallest active-node count over all non-placeholder steps.

        NOTE(review): 999 is a magic upper bound on the active-node count;
        graphs with more than 999 simultaneously-active nodes would be
        mis-handled.
        """
        min_len = 999
        for idx, n in enumerate(active_node_list):
            if idx in free_vars:
                continue
            if len(n) < min_len:
                min_len = len(n)
        return min_len
    
    def _search_max_chunk_region(self, active_node, peak_node):
        """Grow outwards from the peak node, in both directions, until a step
        with the minimal active-node count is found; return (start, end).

        Raises RuntimeError if a boundary (placeholder or node 0) is hit
        before such a step is found.
        """
        free_vars = self._get_free_var()
        min_var = self._get_min_free_var(active_node, free_vars)
        
        # from peak_node to free_var
        chunk_region_start = None
        for i in range(peak_node, -1, -1):
            if len(active_node[i]) == min_var:
                chunk_region_start = i + 1
                break
            if i in free_vars or i == 0:
                raise RuntimeError()
        # from peak_node to len-2
        chunk_region_end = None
        for i in range(peak_node, len(active_node) - 1):
            if len(active_node[i]) == min_var:
                chunk_region_end = i - 1
                break
            if i in free_vars or i == 0:
                raise RuntimeError()
        return chunk_region_start, chunk_region_end
    
    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
        """Enumerate every (before, after) node pair bracketing the peak and
        collect, per pair, the dimensions whose index traces match and are
        not involved in computation (i.e. safe to chunk over)."""
        possible_chunk_region = []
        for before_idx in range(max_chunk_region[0], peak_node):
            for after_idx in range(peak_node, max_chunk_region[1]):
                # skip non compute nodes
                if any(op in ['placeholder', 'get_attr', 'output'] for op in 
                       [self.node_list[before_idx].op, self.node_list[after_idx].op]):
                    continue
                if any(any(i in name for i in ['getitem', 'getattr']) for name in 
                       [self.node_list[before_idx].name, self.node_list[after_idx].name]):
                    continue
                
                # select free dim
                before_trace = self.index_tracer.idx_trace_list[before_idx]
                after_trace = self.index_tracer.idx_trace_list[after_idx]
                free_dim = []
                for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
                   if (before_trace['idx'][i] == after_trace['idx'][i] and 
                       before_trace['idx'][i] not in before_trace['compute'] and
                       after_trace['idx'][i] not in after_trace['compute']):
                       free_dim.append(i)
                possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim})
        return possible_chunk_region
    
    def search_region(self):
        """Run the end-to-end search: estimate memory, find the peak, and
        enumerate chunkable regions around it.

        NOTE(review): the computed ``possible_chunk_regions`` is discarded and
        nothing is returned — this method looks unfinished; confirm intended
        return value.
        """
        mem_peak, mem_after, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
        peak_nodes = self._find_peak_node(mem_peak)
        for idx, peak_node in enumerate(peak_nodes):
            max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
            possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
oahzxl's avatar
oahzxl committed
676
677


oahzxl's avatar
oahzxl committed
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
    new_shape = "["
    for idx, i in enumerate(shape):
        if idx == chunk_dim:
            new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
        else:
            new_shape += ":"
        new_shape += ", "
    new_shape = new_shape[:-2] + "]"
    return new_shape


def _get_first_non_single_dim(shape):
    for idx, i in enumerate(shape):
        if i == 1:
            continue
        else:
            return idx
    raise RuntimeError("can not get first non single dim for shape", shape)


def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
    """Generate the source lines that open a chunked for-loop.

    Allocates ``chunk_result`` with the output's shape, then opens a loop
    that slices the single chunk input into ``chunk_tensor`` one chunk at a
    time along its first non-singleton dimension.

    Args:
        chunk_input_meta: list with exactly one meta node for the chunk input.
        chunk_output: meta node of the loop's final output (shape source).
        chunk_size (int): step of the generated loop.

    Raises:
        NotImplementedError: if more than one chunk input is given.
    """
    if len(chunk_input_meta) != 1:
        raise NotImplementedError("input with size %d not implemented" % len(chunk_input_meta))

    in_node = chunk_input_meta[0]
    in_shape = in_node.meta['tensor_meta'].shape
    loop_dim = _get_first_non_single_dim(in_shape)
    in_slice = _gen_chunk_slice_dim(loop_dim, "gen_chunk_idx", in_shape)
    out_shape = str(list(chunk_output.meta['tensor_meta'].shape))

    context = (
        f"chunk_result = torch.empty({out_shape}, dtype={in_node.name}.dtype, "
        f"device={in_node.name}.device); chunk_size = {chunk_size}\n"
        f"for gen_chunk_idx in range(0, {in_node.name}.shape[{loop_dim}], chunk_size):\n"
        f"    chunk_tensor = {in_node.name}{in_slice}\n"
    )
    return context


oahzxl's avatar
oahzxl committed
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
    """Generate the source lines that close a chunked for-loop.

    Writes the per-chunk result into its slice of ``chunk_result``, promotes
    the accumulated result to the output's name, frees the loop temporaries,
    and additionally frees the chunk input when the loop is its last use.

    Args:
        chunk_outputs: the node produced at the end of the chunk region.
        chunk_inputs: list of chunk-input nodes (only the first is used).
        node_list: all graph nodes, used to resolve node positions.
    """
    input_name = chunk_inputs[0].name
    output_name = chunk_outputs.name
    output_idx = _find_idx_by_name(output_name, node_list)
    output_shape = chunk_outputs.meta['tensor_meta'].shape
    loop_dim = _get_first_non_single_dim(output_shape)
    result_slice = _gen_chunk_slice_dim(loop_dim, "gen_chunk_idx", output_shape)

    context = f"    chunk_result{result_slice} = {output_name}\n"
    context += f"{output_name} = chunk_result;  chunk_result = None;  chunk_size = None"

    # free the chunk input as well when none of its users appear after the loop
    consumers = list(chunk_inputs[0].users.keys())
    if all(_find_idx_by_name(user.name, node_list) <= output_idx for user in consumers):
        context += f";  {input_name} = None"

    context += "\n"
    return context

oahzxl's avatar
init  
oahzxl committed
735
736
737
738
739
740
741
742
743
744
745
746
747
748

def _find_input_and_output_nodes(nodes: List[Node]):
    """
    Find the input and output node names which are not found in the given list of nodes.
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list
    # we treat that input node as the input of the checkpoint function
    for node in nodes:
        for input_node in node._input_nodes.keys():
            node_repr = repr(input_node)
            if input_node not in nodes and node_repr not in input_nodes:
oahzxl's avatar
oahzxl committed
749
                input_nodes.append(input_node)
oahzxl's avatar
init  
oahzxl committed
750
751
752
753
754
755
756

    # if a node has a user node which is not in the node list
    # we treat that user node as the node receiving the current node output
    for node in nodes:
        for output_node in node.users.keys():
            node_repr = repr(node)
            if output_node not in nodes and node_repr not in output_nodes:
oahzxl's avatar
oahzxl committed
757
                output_nodes.append(output_node)
oahzxl's avatar
init  
oahzxl committed
758
759
760
761

    return input_nodes, output_nodes


oahzxl's avatar
oahzxl committed
762
763
764
765
766
def _find_idx_by_name(name, nodes_list):
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
    raise RuntimeError("name %s not found in node list" % name)
oahzxl's avatar
init  
oahzxl committed
767
768


oahzxl's avatar
oahzxl committed
769
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
    """Emit forward code in which selected node regions are wrapped in a
    chunked for-loop to reduce peak activation memory.

    Args:
        body: forward code lines, appended in place
        ckpt_func: checkpoint functions code (unused here; kept for interface parity)
        nodes: graph.nodes of the graph being compiled
        emit_node_func: function to emit a single node
        delete_unused_value_func: function to remove values after their last use
        meta_nodes: node list of the profiled meta graph; indices mirror ``nodes``
        delete/meta_graph: profiled meta GraphModule, used for memory estimation
            and chunk-region search
    """

    # find the offload regions
    # NOTE(review): hard-coded (start, end) node-index region — presumably a
    # WIP placeholder until ChunkRegionSearch drives the selection; confirm.
    chunk_regions = [(58, 62)]
    chunk_starts = [item[0] for item in chunk_regions]
    chunk_ends = [item[1] for item in chunk_regions]
    chunk_inputs = []
    chunk_outputs = []
    within_chunk_region = False

    node_list = list(nodes)

    # NOTE(review): the three blocks below run analyses whose results are
    # discarded — they look like debugging/bring-up calls; confirm whether
    # they should remain in the final version.
    memory_estimator = MemoryEstimator()
    memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
    memory_estimator.estimate_chunk_inference_mem(meta_graph)

    node_index_tracer = NodeIndexTracer(meta_graph)
    node_index_tracer.trace_node_idx()
    
    chunk_region_search = ChunkRegionSearch(meta_graph)
    chunk_region_search.search_region()

    # find the input and output var names for each offload region
    for idx, (start, end) in enumerate(chunk_regions):
        offload_node_list = node_list[start:end + 1]
        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
        chunk_inputs.append(inputs)
        chunk_outputs.append(outputs)

    # map region inputs/outputs to their positions in node_list
    # NOTE(review): chunk_outputs_idx is computed but never used.
    chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
    chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs]
    chunk_inputs_names = []
    for i in chunk_inputs:
        for j in i:
            chunk_inputs_names.append(j.name)
    
    # this flag is to prevent repeated insert of save tensors
    # hooks definition in ckpt_func
    node_idx = 0
    region_idx = 0
    while node_idx < len(node_list):
        node = node_list[node_idx]

        if node_idx in chunk_starts:
            within_chunk_region = True
                
            # add for loop
            chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
            body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))

        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
            if node_idx in chunk_starts:
                body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)')
            # indent the emitted line so it lives inside the generated for-loop
            body[-1] = '    ' + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)

        else:
            emit_node_func(node, body)
            # NOTE(review): chunk_inputs holds lists of Node objects, so the
            # int ``node_idx`` is never "in" it — this condition is always
            # True. Probably chunk_inputs_idx (flattened) was intended.
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_inputs_names)

        if node_idx in chunk_ends:
            # close the loop: ``node`` is the region's last node / output
            body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
            within_chunk_region = False
            region_idx += 1

        node_idx += 1
oahzxl's avatar
init  
oahzxl committed
848
849
850
851


if CODEGEN_AVAILABLE:

oahzxl's avatar
oahzxl committed
852
    class ChunkCodeGen(CodeGen):
oahzxl's avatar
oahzxl committed
853
854
        def __init__(self, meta_graph):
            """Create a codegen that consults ``meta_graph`` when emitting code.

            NOTE(review): meta_graph is presumably a GraphModule profiled
            ahead of time so its nodes carry ``tensor_meta`` — confirm with
            the caller.
            """
            super().__init__()
            # graph carrying per-node metadata used for chunk planning
            self.meta_graph = meta_graph
            # flat node list; positions mirror the graph being compiled
            self.meta_node = list(meta_graph.graph.nodes)
oahzxl's avatar
init  
oahzxl committed
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957

        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
            free_vars: List[str] = []
            body: List[str] = []
            globals_: Dict[str, Any] = {}
            wrapped_fns: Dict[str, None] = {}

            # Wrap string in list to pass by reference
            maybe_return_annotation: List[str] = ['']

            def add_global(name_hint: str, obj: Any):
                """Add an obj to be tracked as a global.

                We call this for names that reference objects external to the
                Graph, like functions or types.

                Returns: the global name that should be used to reference 'obj' in generated source.
                """
                if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
                    # HACK: workaround for how torch custom ops are registered. We
                    # can't import them like normal modules so they must retain their
                    # fully qualified name.
                    return _get_qualified_name(obj)

                # normalize the name hint to get a proper identifier
                global_name = namespace.create_name(name_hint, obj)

                if global_name in globals_:
                    assert globals_[global_name] is obj
                    return global_name
                globals_[global_name] = obj
                return global_name

            # set _custom_builtins here so that we needn't import colossalai in forward
            _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

            # Pre-fill the globals table with registered builtins.
            for name, (_, obj) in _custom_builtins.items():
                add_global(name, obj)

            def type_repr(o: Any):
                if o == ():
                    # Empty tuple is used for empty tuple type annotation Tuple[()]
                    return '()'

                typename = _type_repr(o)

                if hasattr(o, '__origin__'):
                    # This is a generic type, e.g. typing.List[torch.Tensor]
                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                    origin_typename = add_global(_type_repr(origin_type), origin_type)

                    if hasattr(o, '__args__'):
                        # Assign global names for each of the inner type variables.
                        args = [type_repr(arg) for arg in o.__args__]

                        if len(args) == 0:
                            # Bare type, such as `typing.Tuple` with no subscript
                            # This code-path used in Python < 3.9
                            return origin_typename

                        return f'{origin_typename}[{",".join(args)}]'
                    else:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python 3.9+
                        return origin_typename

                # Common case: this is a regular module name like 'foo.bar.baz'
                return add_global(typename, o)

            def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:

                def _get_repr(arg):
                    # Handle NamedTuples (if it has `_fields`) via add_global.
                    if isinstance(arg, tuple) and hasattr(arg, '_fields'):
                        qualified_name = _get_qualified_name(type(arg))
                        global_name = add_global(qualified_name, type(arg))
                        return f"{global_name}{repr(tuple(arg))}"
                    return repr(arg)

                args_s = ', '.join(_get_repr(a) for a in args)
                kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
                if args_s and kwargs_s:
                    return f'{args_s}, {kwargs_s}'
                return args_s or kwargs_s

            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use: Dict[Node, Node] = {}
            user_to_last_uses: Dict[Node, List[Node]] = {}

            def register_last_uses(n: Node, user: Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))
oahzxl's avatar
oahzxl committed
958
959
960
            
            _delete_free_var_from_last_use(user_to_last_uses)
            
oahzxl's avatar
init  
oahzxl committed
961
            # NOTE: we add a variable to distinguish body and ckpt_func
oahzxl's avatar
oahzxl committed
962
            def delete_unused_values(user: Node, body, to_keep=[]):
oahzxl's avatar
init  
oahzxl committed
963
964
965
966
967
968
969
970
971
972
973
                """
                Delete values after their last use. This ensures that values that are
                not used in the remainder of the code are freed and the memory usage
                of the code is optimal.
                """
                if user.op == 'placeholder':
                    return
                if user.op == 'output':
                    body.append('\n')
                    return
                nodes_to_delete = user_to_last_uses.get(user, [])
oahzxl's avatar
oahzxl committed
974
                nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
oahzxl's avatar
init  
oahzxl committed
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
                if len(nodes_to_delete):
                    to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                    body.append(f';  {to_delete_str}\n')
                else:
                    body.append('\n')

            # NOTE: we add a variable to distinguish body and ckpt_func
            def emit_node(node: Node, body):
                maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
                if node.op == 'placeholder':
                    assert isinstance(node.target, str)
                    maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
                    free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                    raw_name = node.target.replace('*', '')
                    if raw_name != repr(node):
                        body.append(f'{repr(node)} = {raw_name}\n')
                    return
                elif node.op == 'call_method':
                    assert isinstance(node.target, str)
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
                        f'({_format_args(node.args[1:], node.kwargs)})')
                    return
                elif node.op == 'call_function':
                    assert callable(node.target)
                    # pretty print operators
                    if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                        assert isinstance(node.args, tuple)
                        body.append(f'{repr(node)}{maybe_type_annotation} = '
                                    f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
                        return

                    # pretty print inplace operators; required for jit.script to work properly
                    # not currently supported in normal FX graphs, but generated by torchdynamo
                    if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
                        body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  '
                                    f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
                        return

                    qualified_name = _get_qualified_name(node.target)
                    global_name = add_global(qualified_name, node.target)
                    # special case for getattr: node.args could be 2-argument or 3-argument
                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                    if global_name == 'getattr' and \
                    isinstance(node.args, tuple) and \
                    isinstance(node.args[1], str) and \
                    node.args[1].isidentifier() and \
                    len(node.args) == 2:
                        body.append(
                            f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                        return
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                    if node.meta.get('is_wrapped', False):
                        wrapped_fns.setdefault(global_name)
                    return
                elif node.op == 'call_module':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = '
                                f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
                    return
                elif node.op == 'get_attr':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
                    return
                elif node.op == 'output':
                    if node.type is not None:
                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                    body.append(self.generate_output(node.args[0]))
                    return
                raise NotImplementedError(f'node: {node.op} {node.target}')

            # Modified for activation checkpointing
            ckpt_func = []

            # if any node has a list of labels for activation_checkpoint, we
            # will use nested type of activation checkpoint codegen
oahzxl's avatar
oahzxl committed
1052
            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph)
oahzxl's avatar
init  
oahzxl committed
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body
                # have been emitted. To continue to have valid Python code, emit a
                # single pass statement
                body.append('pass\n')

            if len(wrapped_fns) > 0:
                wrap_name = add_global('wrap', torch.fx.wrap)
                wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
            else:
                wrap_stmts = ''

            if self._body_transformer:
                body = self._body_transformer(body)

            for name, value in self.additional_globals():
                add_global(name, value)

            # as we need colossalai.utils.checkpoint, we need to import colossalai
            # in forward function
            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
            prologue = ''.join(ckpt_func) + prologue
            prologue = prologue

            code = ''.join(body)
            code = '\n'.join('    ' + line for line in code.split('\n'))
            fn_code = f"""
{wrap_stmts}

{prologue}
oahzxl's avatar
oahzxl committed
1084
1085
{code}"""   
            print(fn_code)
oahzxl's avatar
init  
oahzxl committed
1086
            return PythonCode(fn_code, globals_)