Commit 212b5b1b authored by oahzxl's avatar oahzxl
Browse files

add comments

parent 19cc64b1
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Any, Dict, Iterable, List, Tuple
import torch
from torch.fx.graph import (
......@@ -128,37 +128,42 @@ def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, bod
def emit_code_with_chunk(
body,
nodes,
body: List[str],
nodes: Iterable[Node],
emit_node_func,
delete_unused_value_func,
search_chunk: SearchChunk,
chunk_infos,
chunk_infos: List,
):
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.
"""
Emit code with chunk according to chunk_infos.
It will generate a for loop in chunk regions, and replace inputs
and outputs of regions with chunked variables.
Args:
body: forward code
ckpt_func: checkpoint functions code
nodes: graph.nodes
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
search_chunk: the class to search all chunks
chunk_infos: store all information about all chunks.
"""
node_list = list(nodes)
chunk_regions = [i["region"] for i in chunk_infos]
chunk_starts = [i[0] for i in chunk_regions]
chunk_ends = [i[1] for i in chunk_regions]
# chunk region
chunk_starts = [i["region"][0] for i in chunk_infos]
chunk_ends = [i["region"][1] for i in chunk_infos]
chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
# chunk inputs
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
j.name for i in chunk_inputs_non_chunk for j in i
]
# chunk outputs
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
......@@ -170,6 +175,7 @@ def emit_code_with_chunk(
while node_idx < len(node_list):
node = node_list[node_idx]
# if is chunk start, generate for loop start
if node_idx in chunk_starts:
within_chunk_region = True
region_idx = chunk_starts.index(node_idx)
......@@ -203,6 +209,7 @@ def emit_code_with_chunk(
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, chunk_inputs_names)
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(
......
......@@ -115,4 +115,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory):
if __name__ == "__main__":
_test_autochunk_codegen(0, 32, 64, None)
_test_autochunk_codegen(0, 32, 64, 25)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment