Unverified Commit 63199c66 authored by oahzxl, committed by GitHub

[autochunk] support transformer (#2526)

parent 6e0faa70
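For orientation, the flow this PR wires up is: trace the model with meta inputs, run MetaInfoProp to record shapes, build an AutoChunkCodeGen from the meta graph, then attach it to a ColoGraphModule and recompile. Below is a condensed sketch of that flow, adapted from the _build_autochunk benchmark helper added later in this diff (the node/pair tensors are just the Evoformer benchmark's example inputs; this is a sketch, not the canonical API):

import torch
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor

def build_autochunk_module(model, node, pair, max_memory=None):
    # shape propagation on a symbolically traced copy of the model
    gm_prop = torch.fx.symbolic_trace(model)
    MetaInfoProp(gm_prop).propagate(MetaTensor(node, fake_device="cuda:0"),
                                    MetaTensor(pair, fake_device="cuda:0"))
    # search chunk regions and generate chunked code
    codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
    # retrace with ColoTracer and attach the new codegen
    graph = ColoTracer().trace(model,
                               meta_args={
                                   "node": node.to(torch.device("meta")),
                                   "pair": pair.to(torch.device("meta")),
                               })
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()
    return gm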
@@ -3,9 +3,12 @@ from typing import Any, Dict, Iterable, List, Tuple
import torch

import colossalai
+from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

-if CODEGEN_AVAILABLE:
+AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
+
+if AUTOCHUNK_AVAILABLE:
    from torch.fx.graph import (
        CodeGen,
        PythonCode,
@@ -272,7 +275,7 @@ def emit_code_with_chunk(
        node_idx += 1

-if CODEGEN_AVAILABLE:
+if AUTOCHUNK_AVAILABLE:

    class AutoChunkCodeGen(CodeGen):
...
@@ -8,7 +8,13 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
-from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import (
+    find_chunk_compute_input_and_output_nodes,
+    get_logger,
+    get_node_shape,
+    is_non_compute_node,
+    is_non_compute_node_except_placeholder,
+)


class SearchChunk(object):
@@ -114,6 +120,12 @@ class SearchChunk(object):
            chunk_region_start (int)
            chunk_region_end (int)
        """
+        # check if peak node already in chunkinfo
+        if chunk_regions is not None:
+            for i in chunk_regions:
+                if i["region"][0] < peak_node_idx <= i["region"][1]:
+                    return None
+
        free_vars = self._get_free_var_idx()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
@@ -152,55 +164,6 @@ class SearchChunk(object):
            chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
if len(start_traces) > 1:
continue
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# must have users
if len(end_node.users) == 0:
continue
# check index source align
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
# check index compute
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
# check index duplicate
if not self.trace_flow.check_index_duplicate(chunk_info):
continue
chunk_infos.append(chunk_info)
return chunk_infos
    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
        """
        Search every possible region within the max chunk region.
@@ -228,9 +191,8 @@ class SearchChunk(object):
            if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
                    self.trace_indice.node_list[end_idx]):
                continue

            # select free dim
-            chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
+            chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
            if len(chunk_info) > 0:
                possible_chunk_region.extend(chunk_info)
        return possible_chunk_region
...
@@ -5,6 +5,7 @@ from .utils import is_non_compute_node


class SelectChunk(object):

    def __init__(
        self,
        trace_indice: TraceIndice,
@@ -17,13 +18,11 @@ class SelectChunk(object):
        self.reorder_graph = reorder_graph
        if max_memory is not None:
            self.stratge = "fit_memory"
            self.max_memory = max_memory    # MB
        else:
            self.stratge = "min_memory"

-    def _select_best_chunk_region(
-        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-    ):
+    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
        if self.stratge == "min_memory":
            best_region = self._select_min_memory_chunk_region(
                possible_chunk_regions,
@@ -44,9 +43,8 @@ class SelectChunk(object):
            raise RuntimeError()
        return best_region
-    def _select_fit_memory_chunk_region(
-        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-    ):
+    def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
+                                        mem_peak):
        # stop chunk if max memory satisfy memory limit
        if max(mem_peak) < self.max_memory:
            return None
@@ -63,33 +61,26 @@ class SelectChunk(object):
        if len(possible_chunk_regions) == 0:
            return None

+        max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
+                                     max([i["region"][1] for i in possible_chunk_regions]))
+
        # get mem for chunk region
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
-                self.trace_indice.node_list, cur_region
-            )
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
            cur_chunk_infos = chunk_infos + [cur_region]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
-                cur_node_list, cur_chunk_infos
-            )[0]
-            cur_chunk_region_peak = cur_mem_peak[
-                max_chunk_region[0] : max_chunk_region[1] + 1
-            ]
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
            if cur_chunk_region_max_peak < self.max_memory:
-                regions_dict.append(
-                    {
-                        "chunk_info": region,
-                        "chunk_max_mem": cur_chunk_region_max_peak,
-                        "chunk_len": self._get_compute_node_num(
-                            region["region"][0], region["region"][1]
-                        ),
-                        "reorder_chunk_info": cur_region,
-                        "reorder_node_list": cur_node_list,
-                    }
-                )
+                regions_dict.append({
+                    "chunk_info": region,
+                    "chunk_max_mem": cur_chunk_region_max_peak,
+                    "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+                    "reorder_chunk_info": cur_region,
+                    "reorder_node_list": cur_node_list,
+                })

        # no region found
        if len(regions_dict) == 0:
            raise RuntimeError("Search failed. Try a larger memory threshold.")
@@ -113,20 +104,13 @@ class SelectChunk(object):
            chunk_size *= 2
            reorder_chunk_info["chunk_size"] = chunk_size
            cur_chunk_infos = chunk_infos + [reorder_chunk_info]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
-                chunk_region_dict["reorder_node_list"], cur_chunk_infos
-            )[0]
-            cur_chunk_max_mem = max(
-                cur_mem_peak[
-                    reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
-                    + 1
-                ]
-            )
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
+                                                                             cur_chunk_infos)[0]
+            cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])

        # search exact size
        chunk_info = chunk_region_dict["chunk_info"]
-        chunk_info["chunk_size"] = self._chunk_size_binary_search(
-            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
-        )
+        chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
+                                                                  chunk_infos)
        return chunk_info

    def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@@ -139,12 +123,9 @@ class SelectChunk(object):
            mid = int((left + right) / 2 + 0.5)
            chunk_info["chunk_size"] = mid
            cur_chunk_infos = chunk_infos + [chunk_info]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
-                chunk_region_dict["reorder_node_list"], cur_chunk_infos
-            )[0]
-            cur_chunk_max_mem = max(
-                cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
-            )
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
+                                                                             cur_chunk_infos)[0]
+            cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
            if cur_chunk_max_mem >= self.max_memory:
                right = mid - gap
            else:
@@ -153,14 +134,13 @@ class SelectChunk(object):

    def _get_compute_node_num(self, start, end):
        count = 0
-        for i in self.trace_indice.node_list[start : end + 1]:
+        for i in self.trace_indice.node_list[start:end + 1]:
            if not is_non_compute_node(i):
                count += 1
        return count
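The fit-memory path above grows chunk_size by doubling against the memory estimate and then calls _chunk_size_binary_search on the bracket [chunk_size // 2, chunk_size]. A standalone sketch of that bracket-then-bisect pattern (estimate_peak_mem is a hypothetical monotonic callable standing in for estimate_chunk_inference_mem, not the library API):

def largest_chunk_size_under_budget(estimate_peak_mem, max_memory):
    # exponential phase: double until the chunk no longer fits the memory budget
    upper = 1
    while estimate_peak_mem(upper) < max_memory:
        upper *= 2
    # binary search the largest size in [upper // 2, upper] that still fits
    left, right = upper // 2, upper
    while left < right:
        mid = (left + right + 1) // 2
        if estimate_peak_mem(mid) < max_memory:
            left = mid
        else:
            right = mid - 1
    return left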
-    def _select_min_memory_chunk_region(
-        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-    ):
+    def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
+                                        mem_peak):
        # remove illegal regions
        illegal_regions = []
        for i in possible_chunk_regions:
@@ -173,37 +153,31 @@ class SelectChunk(object):
        if len(possible_chunk_regions) == 0:
            return None

+        # get max possible chunk region
+        max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
+                                     max([i["region"][1] for i in possible_chunk_regions]))
+
        # get mem for chunk region
-        regions_dict = []
+        regions_dict_list = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
-                self.trace_indice.node_list, cur_region
-            )
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
            cur_chunk_infos = chunk_infos + [cur_region]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
-                cur_node_list, cur_chunk_infos
-            )[0]
-            cur_chunk_region_peak = cur_mem_peak[
-                max_chunk_region[0] : max_chunk_region[1] + 1
-            ]
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
-            regions_dict.append(
-                {
-                    "chunk_info": region,
-                    "chunk_max_mem": cur_chunk_region_max_peak,
-                    "chunk_len": self._get_compute_node_num(
-                        region["region"][0], region["region"][1]
-                    ),
-                    "reorder_chunk_info": cur_region,
-                    "reorder_node_list": cur_node_list,
-                }
-            )
+            regions_dict_list.append({
+                "chunk_info": region,
+                "chunk_max_mem": cur_chunk_region_max_peak,
+                "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+                "reorder_chunk_info": cur_region,
+                "reorder_node_list": cur_node_list,
+            })

        # select the min mem
-        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
+        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
        best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
-        best_region = regions_dict[best_region_idx]["chunk_info"]
+        best_region = regions_dict_list[best_region_idx]["chunk_info"]
        if best_region is not None:
            best_region["chunk_size"] = 1
        return best_region
@@ -216,9 +190,7 @@ class SelectChunk(object):
            return False
        for i in chunk_infos:
            region = i["region"]
-            if not (
-                (chunk_region_start > region[1] and chunk_region_end > region[1])
-                or (chunk_region_start < region[0] and chunk_region_end < region[0])
-            ):
+            if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
+                    (chunk_region_start < region[0] and chunk_region_end < region[0])):
                return False
        return True
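The rewrapped condition above is the entire legality test: a candidate chunk region is legal only if it lies strictly before or strictly after every region already chosen. A minimal standalone sketch of the same predicate (regions are inclusive (start, end) index pairs, as elsewhere in this file):

def region_is_legal(candidate, chosen_regions):
    start, end = candidate
    for c_start, c_end in chosen_regions:
        entirely_after = start > c_end and end > c_end
        entirely_before = start < c_start and end < c_start
        if not (entirely_after or entirely_before):
            return False
    return True

# region_is_legal((10, 20), [(0, 5), (30, 40)])  -> True
# region_is_legal((10, 20), [(15, 25)])          -> False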
@@ -8,9 +8,9 @@ from .utils import (
    find_chunk_compute_input_and_output_nodes,
    find_idx_by_name,
    flat_list,
+    get_node_name,
    get_node_shape,
    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
)
@@ -79,43 +79,6 @@ class TraceFlow(object):
            return node_dim
        return None
def check_index_duplicate(self, chunk_infos, return_dim=False):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
if inherit_dim:
input_dim_after_node[k] = inherit_dim
for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
if is_non_compute_node_except_placeholder(node):
continue
count = 0
duplicate_dims = []
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
duplicate_dim = []
duplicate_flag = False
dim_source = node_trace_source[node_dim]
for k, v in dim_source.items():
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
if k in input_dim_after_node and input_dim_after_node[k] in v:
duplicate_flag = True
duplicate_dim.append((k, v))
duplicate_dims.append(duplicate_dim)
if duplicate_flag:
count += 1
if count > 1:
if return_dim:
return False, duplicate_dims
else:
return False
if return_dim:
return True, None
else:
return True
    def _assgin_single_node_flow(
        self,
        arg_node: Node,
@@ -225,9 +188,12 @@ class TraceFlow(object):
            if flow_flag == False:
                return None

-        if len(arg_list) == 2:
-            if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
+        if len(arg_list) >= 2:
+            # need to mark fix dim
+            if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
                for arg in arg_list:
+                    if get_node_shape(arg) is None:
+                        continue
                    if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
                        continue
                    arg_chunk_dim = all_node_info[arg]["chunk_dim"]
@@ -240,9 +206,8 @@ class TraceFlow(object):
                            return None
                        if i not in arg_fix_dim:
                            arg_fix_dim.append(i)
-            elif "einsum" in cur_node.name:
-                pass
-            elif "matmul" in cur_node.name:
+            elif any(i == get_node_name(cur_node)
+                     for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
                pass
            else:
                raise NotImplementedError()
@@ -426,7 +391,7 @@ class TraceFlow(object):
        reshape_size = {}
        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
        for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
-            if any(i in node.name for i in ["reshape", "view"]):
+            if any(i == get_node_name(node) for i in ["reshape", "view"]):
                reshape_args = flat_list(node.args[1:])
                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                new_shape = ""
@@ -443,3 +408,62 @@ class TraceFlow(object):
        reshape_size[node.name] = [origin_shape, new_shape]
        chunk_info["reshape_size"] = reshape_size
        return chunk_info
def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
continue
# flow search
chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
"""
check if the region start and end are legal
"""
# dim cannot be None
if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
return False
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
return False
# must have users
if len(end_node.users) == 0:
return False
# check index source align
if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
return False
# check index compute
if not self.check_index_compute(start_idx, end_dim, end_node, end_idx):
return False
return True
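find_chunk_info and _check_region_start_end only decide along which dimension a region can be computed piece by piece; the code AutoChunkCodeGen later emits performs the usual chunked evaluation: slice the chosen dimension, run the region on each slice, and concatenate the results. A toy illustration of that equivalence in plain PyTorch, independent of the autochunk internals:

import torch

def chunked_apply(fn, x, dim, chunk_size):
    # evaluate fn on slices along `dim` and stitch the results back together
    outs = [fn(piece) for piece in torch.split(x, chunk_size, dim=dim)]
    return torch.cat(outs, dim=dim)

x = torch.randn(4, 64, 32)
fn = lambda t: torch.softmax(t, dim=-1) * 2.0    # independent across the chunked dim
assert torch.allclose(chunked_apply(fn, x, dim=1, chunk_size=16), fn(x))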
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

from torch.fx.node import Node

from colossalai.logging import get_dist_logger

+NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
+NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
logger = get_dist_logger()


-def get_logger():
+def get_logger() -> Any:
    return logger

@@ -37,7 +39,7 @@ def find_first_tensor_arg(node: Node) -> Node:

def is_non_compute_node(node: Node) -> bool:
-    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
+    if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
        return True
    if "getitem" in node.name:
        node_args = flat_list(node.args[1:])

@@ -64,33 +66,33 @@ def is_non_memory_node(node: Node) -> bool:
    return is_non_compute_node(node)


-def is_non_compute_node_except_placeholder(node):
+def is_non_compute_node_except_placeholder(node: Node) -> bool:
    if "placeholder" in node.op:
        return False
    return is_non_compute_node(node)


-def is_non_compute_node_except_placeholder_output(node):
+def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
    if "output" in node.op:
        return False
    return is_non_compute_node_except_placeholder(node)


-def find_idx_by_name(name, nodes_list):
+def find_idx_by_name(name: str, nodes_list: List) -> int:
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
    raise RuntimeError("name %s not found in node list" % name)


-def delete_free_var_from_last_use(user_to_last_uses):
+def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
    for key, value in user_to_last_uses.items():
        for n in value:
            if n.op == "placeholder":
                user_to_last_uses[key].remove(n)


-def find_chunk_all_input_nodes(nodes: List[Node]):
+def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
    """
    Find non-compute input and output node names.
    input nodes are nodes used in the list
@@ -104,7 +106,7 @@ def find_chunk_all_input_nodes(nodes: List[Node]):
    return input_nodes


-def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
+def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, List]:
    """
    Find non-compute input and output node names.
    input nodes are nodes used in the list
@@ -130,3 +132,33 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
            output_nodes.append(node)
    return input_nodes, output_nodes
def get_module_node_name(node: Node) -> str:
"""
get module class name
"""
node_targets = node.target.split(".")
module = node.graph.owning_module
for i in node_targets:
module = getattr(module, i)
module_name = str(module.__class__).split(".")[-1][:-2]
module_name = module_name.lower()
return module_name
def get_node_name(node: Node) -> str:
"""
get node name, stripping the numeric suffix torch.fx appends to repeated ops (e.g. add_1 -> add)
"""
node_name = node.name
if "_" in node_name:
for i in range(len(node_name) - 1, -1, -1):
if node_name[i] == "_":
node_name = node_name[:i]
break
elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]:
continue
else:
break
return node_name
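The two helpers above map an fx node back to a canonical operator name: get_module_node_name resolves the owning module's class for call_module nodes, while get_node_name strips the numeric suffix fx appends when an op name repeats. A small usage sketch (the toy module below is illustrative only, not part of this PR):

import torch
import torch.fx
from colossalai.autochunk.utils import get_module_node_name, get_node_name

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        x = self.proj(x)
        x = torch.add(x, 1)
        x = torch.add(x, 2)    # traced as node "add_1"
        return x

gm = torch.fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "call_module":
        print(get_module_node_name(node))    # -> "linear"
    elif node.op == "call_function":
        print(get_node_name(node))           # "add_1" becomes "add"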
import time
import torch
import torch.fx
from simple_evoformer import base_evoformer, openfold_evoformer
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
torch.cuda.reset_peak_memory_stats()
now_mem = torch.cuda.memory_allocated() / 1024**2
loop = 3
with torch.no_grad():
for _ in range(loop // 2 + 1):
if chunk_size:
model(node, pair, chunk_size)
else:
model(node, pair)
torch.cuda.synchronize()
time1 = time.time()
for _ in range(loop):
if chunk_size:
model(node, pair, chunk_size)
else:
model(node, pair)
torch.cuda.synchronize()
time2 = time.time()
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem))
def _build_autochunk(model, max_memory, node, pair):
# trace the module and replace codegen
graph = ColoTracer().trace(
model,
meta_args={
"node": node.to(torch.device("meta")),
"pair": pair.to(torch.device("meta")),
},
)
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
interp = MetaInfoProp(gm_prop)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
# now run it twice to get meta info in graph module, not necessary
gm = torch.fx.GraphModule(model, graph)
interp = MetaInfoProp(gm)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
# set code_gen
codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()
# print
# code = graph.python_code("self").src
# print(code)
return gm
def benchmark_evoformer():
# init data and model
msa_len = 128
pair_len = 256
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = base_evoformer().cuda()
# build autochunk model
# max_memory = 1000 # MB, fit memory mode
max_memory = None # min memory mode
autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair)
# build openfold
chunk_size = 64
openfold = openfold_evoformer().cuda()
# benchmark
_benchmark_evoformer(model, node, pair, "base")
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer(autochunk, node, pair, "autochunk")
if __name__ == "__main__":
benchmark_evoformer()
from typing import Any, Dict, List
import torch
import torch.fx
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.autochunk.utils import flat_list
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def assert_codegen_run(
model: Any,
meta_args: List,
concrete_args: List = None,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
if concrete_args is None:
concrete_args = []
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
interp = MetaInfoProp(meta_graph)
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert chunk in code
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
# assert result
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
model.cuda()
with torch.no_grad():
out_gm = gm(*inputs)
out_model = model(*inputs)
out_gm = flat_list(out_gm)
out_model = flat_list(out_model)
for out_gm_i, out_model_i in zip(out_gm, out_model):
assert torch.allclose(out_gm_i, out_model_i,
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(out_gm_i - out_model_i))
return chunks
def run_test(
rank: int,
data_args: tuple,
max_memory: int,
get_model: Any,
get_data: Any,
print_code: bool,
print_mem: bool,
print_progress: bool,
get_chunk_target: Any = None,
) -> None:
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = get_model()
meta_args, concrete_args = get_data(*data_args)
chunks = assert_codegen_run(
model,
meta_args=meta_args,
concrete_args=concrete_args,
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_progress=print_progress,
)
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]
chunk_target = get_chunk_target()[max_memory]
assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
str(chunk_found),
str(chunk_target),
)
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerBlock
HAS_REPO = True
except:
HAS_REPO = False
from test_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
def get_model():
model = EvoformerBlock(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.15,
inf=1e4,
eps=1e-4,
is_multimer=False,
).eval().cuda()
return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
meta_args = [
("m", node),
("z", pair),
("msa_mask", node_mask),
("pair_mask", pair_mask),
]
concrete_args = [("chunk_size", None), ("_mask_trans", True)]
return meta_args, concrete_args
def get_chunk_target() -> Dict:
return {
None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)],
20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
24: [(118, 123)],
}
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
def test_evoformer_block(data_args, max_memory):
run_func = partial(
run_test,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerStack
HAS_REPO = True
except:
HAS_REPO = False
from test_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
def get_model():
model = EvoformerStack(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
c_s=384,
no_heads_msa=8,
no_heads_pair=4,
no_blocks=2, # 48
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.25,
blocks_per_ckpt=None,
inf=1000000000.0,
eps=1e-08,
clear_cache_between_blocks=False,
is_multimer=False,
).eval().cuda()
return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
meta_args = [
("m", node),
("z", pair),
("msa_mask", node_mask),
("pair_mask", pair_mask),
]
concrete_args = [("chunk_size", None), ("_mask_trans", True)]
return meta_args, concrete_args
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
def test_evoformer_stack(data_args, max_memory):
run_func = partial(
run_test,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import ExtraMSABlock
HAS_REPO = True
except:
HAS_REPO = False
from test_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
def get_model():
model = ExtraMSABlock(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.15,
inf=1e4,
eps=1e-4,
ckpt=False,
is_multimer=False,
).eval().cuda()
return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
meta_args = [
("m", node),
("z", pair),
("msa_mask", node_mask),
("pair_mask", pair_mask),
]
concrete_args = [("chunk_size", None), ("_chunk_logits", 1024)]
return meta_args, concrete_args
def get_chunk_target() -> Dict:
return {
None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
(33, 46)],
20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
24: [(126, 131)],
}
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
def test_extramsa_block(data_args, max_memory):
run_func = partial(
run_test,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
print_code=False,
print_mem=False,
print_progress=False,
)
-from functools import partial
+from typing import Any, Dict, List

-import pytest
import torch
import torch.fx
-import torch.multiprocessing as mp
-
-try:
-    from simple_evoformer import base_evoformer
-    HAS_REPO = True
-except:
-    HAS_REPO = False

import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
-from colossalai.fx import ColoTracer, symbolic_trace
-from colossalai.fx._compatibility import is_compatible_with_meta
-from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

-if CODEGEN_AVAILABLE and is_compatible_with_meta():
+if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
+    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


+def assert_codegen_run(
+    model: Any,
+    meta_args: List,
+    concrete_args: List = None,
+    max_memory: int = None,
+    print_mem: bool = False,
+    print_progress: bool = False,
+    print_code: bool = False,
+) -> List[Dict]:
+    if concrete_args is None:
+        concrete_args = []
+    model = model()
+
+    # trace the meta graph and setup codegen
+    meta_graph = symbolic_trace(
+        model,
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+        concrete_args={k: v for k, v in concrete_args},
+    )
+    interp = MetaInfoProp(meta_graph)
+    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
+    interp.propagate(*meta_tensors)
+    codegen = AutoChunkCodeGen(
+        meta_graph,
+        max_memory=max_memory,
+        print_mem=print_mem,
+        print_progress=print_progress,
+    )
+    chunks = codegen.chunk_infos
+
+    # trace and recompile
+    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+    graph = ColoTracer().trace(
+        model.cuda(),
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+        concrete_args={k: v for k, v in concrete_args},
+    )
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+    gm.recompile()
+
+    # assert chunk in code
+    code = graph.python_code("self").src
+    if print_code:
+        print(code)
+    assert "chunk_result = None; chunk_size = None;" in code
+
+    # assert result
+    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    model.cuda().eval()
+    gm.eval()
+    with torch.no_grad():
+        out_gm = gm(*inputs)
+        out_model = model(*inputs)
+    assert torch.allclose(out_gm["sample"], out_model["sample"],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(out_gm["sample"] - out_model["sample"]))
+    return chunks


-def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    with torch.no_grad():
-        non_fx_out = model(node, pair)
-        fx_out = gm(node, pair)
-
-    assert torch.allclose(non_fx_out[0], fx_out[0],
-                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
-                              torch.abs(non_fx_out[0] - fx_out[0]))
-    assert torch.allclose(non_fx_out[1], fx_out[1],
-                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
-                              torch.abs(non_fx_out[1] - fx_out[1]))
-
-
-def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
+def run_test(
+    rank: int,
+    model: Any,
+    data: tuple,
+    max_memory: int,
+    print_code: bool,
+    print_mem: bool,
+    print_progress: bool,
+    get_chunk_target: Any = None,
+) -> None:
    # launch colossalai
    colossalai.launch(
        config={},
@@ -50,55 +98,23 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
    )

    # build model and input
-    model = base_evoformer().cuda()
-    node = torch.randn(1, msa_len, pair_len, 256).cuda()
-    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
-
-    # meta info prop
-    meta_graph = symbolic_trace(model,
-                                meta_args={
-                                    "node": node.to(torch.device("meta")),
-                                    "pair": pair.to(torch.device("meta")),
-                                })    # must use symbolic_trace
-    interp = MetaInfoProp(meta_graph)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
-
-    # trace the module and replace codegen
-    graph = ColoTracer().trace(
-        model,
-        meta_args={
-            "node": node.to(torch.device("meta")),
-            "pair": pair.to(torch.device("meta")),
-        },
-    )
-    graph.set_codegen(codegen)
-    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
-    gm.recompile()
-
-    # assert we have inserted chunk
-    code = graph.python_code("self").src
-    # print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
-
-    _test_fwd(model, gm, node, pair)
-    gpc.destroy()
-
-
-@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
-                    reason='torch version is lower than 1.12.0')
-@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
-@pytest.mark.parametrize("msa_len", [32])
-@pytest.mark.parametrize("pair_len", [64])
-def test_simple_evoformer_codegen(msa_len, pair_len, max_memory):
-    run_func = partial(
-        _test_simple_evoformer_codegen,
-        msa_len=msa_len,
-        pair_len=pair_len,
-        max_memory=max_memory,
-    )
-    mp.spawn(run_func, nprocs=1)
-
-
-if __name__ == "__main__":
-    _test_simple_evoformer_codegen(0, 32, 64, 25)
+    meta_args, concrete_args = data
+    chunks = assert_codegen_run(
+        model,
+        meta_args=meta_args,
+        concrete_args=concrete_args,
+        max_memory=max_memory,
+        print_code=print_code,
+        print_mem=print_mem,
+        print_progress=print_progress,
+    )
+
+    if get_chunk_target is not None:
+        chunk_found = [i["region"] for i in chunks]
+        chunk_target = get_chunk_target()[max_memory]
+        assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
+            str(chunk_found),
+            str(chunk_target),
+        )
+
+    gpc.destroy()
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from diffusers import UNet2DModel
MODELS = [UNet2DModel]
HAS_REPO = True
except:
MODELS = []
HAS_REPO = False
from test_diffuser_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
IN_CHANNELS = 3
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
def get_data(shape: tuple) -> Tuple[List, List]:
sample = torch.randn(shape)
meta_args = [
("sample", sample),
]
concrete_args = [("timestep", 50)]
return meta_args, concrete_args
@pytest.mark.skipif(
True,
reason="not implemented",
)
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [64])
def test_evoformer_block(model, shape, max_memory):
run_func = partial(
run_test,
max_memory=max_memory,
model=model,
data=get_data(shape),
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data=get_data(LATENTS_SHAPE),
max_memory=64,
model=UNet2DModel,
print_code=False,
print_mem=False,
print_progress=False,
)
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerBlock
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
with torch.no_grad():
non_fx_out = model(node, pair, node_mask, pair_mask)
fx_out = gm(node, pair, node_mask, pair_mask)
assert torch.allclose(non_fx_out[0], fx_out[0],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
model = EvoformerBlock(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.15,
inf=1e4,
eps=1e-4,
is_multimer=False,
).eval().cuda()
return model
def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = _build_openfold()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
interp = MetaInfoProp(meta_graph)
interp.propagate(
MetaTensor(node, fake_device="cuda:0"),
MetaTensor(pair, fake_device="cuda:0"),
MetaTensor(node_mask, fake_device="cuda:0"),
MetaTensor(pair_mask, fake_device="cuda:0"),
)
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@pytest.mark.skipif(
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_evoformer_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_evoformer_codegen(0, 32, 64, 24)
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerStack
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1, None)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
with torch.no_grad():
non_fx_out = model(node, pair, node_mask, pair_mask, None)
fx_out = gm(node, pair, node_mask, pair_mask, None)
assert torch.allclose(non_fx_out[0], fx_out[0],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
model = EvoformerStack(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
c_s=384,
no_heads_msa=8,
no_heads_pair=4,
no_blocks=2, # 48
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.25,
blocks_per_ckpt=None,
inf=1000000000.0,
eps=1e-08,
clear_cache_between_blocks=False,
is_multimer=False,
).eval().cuda()
return model
def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = _build_openfold()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"),
MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None)
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@pytest.mark.skipif(
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_stack_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_evoformer_stack_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_evoformer_stack_codegen(0, 32, 64, None)
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import ExtraMSABlock
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
with torch.no_grad():
non_fx_out = model(node, pair, node_mask, pair_mask)
fx_out = gm(node, pair, node_mask, pair_mask)
assert torch.allclose(non_fx_out[0], fx_out[0],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
model = ExtraMSABlock(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.15,
inf=1e4,
eps=1e-4,
ckpt=False,
is_multimer=False,
).eval().cuda()
return model
def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = _build_openfold()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_chunk_logits": 1024,
},
)
interp = MetaInfoProp(meta_graph)
interp.propagate(
MetaTensor(node, fake_device="cuda:0"),
MetaTensor(pair, fake_device="cuda:0"),
MetaTensor(node_mask, fake_device="cuda:0"),
MetaTensor(pair_mask, fake_device="cuda:0"),
)
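    # MetaInfoProp annotates each graph node with (meta) shape and memory data that the chunk search relies on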
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
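    # so the model is traced twice: once for profiling above, once here with ColoTracer so the new codegen can be attached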
graph = ColoTracer().trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_chunk_logits": 1024,
},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
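    # recompile regenerates gm.forward from the graph, so the chunked loops emitted by AutoChunkCodeGen become executable code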
# assert we have inserted chunk
code = graph.python_code("self").src
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@pytest.mark.skipif(
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0 or fastfold is not installed",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_extramsa_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_extramsa_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_extramsa_codegen(0, 32, 64, None)
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from simple_evoformer import base_evoformer
HAS_REPO = True
except Exception:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import symbolic_trace
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
found_regions = [i["region"] for i in chunk_infos]
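    # each "region" is a (start_node_idx, end_node_idx) pair in the traced graph; the targets below are the regions expected for this fixed input size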
if msa_len == 32 and pair_len == 64:
if max_memory is None:
target_regions = [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191),
(161, 166), (198, 203), (7, 57)]
elif max_memory == 20:
target_regions = [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)]
elif max_memory == 25:
target_regions = [(144, 154), (369, 370)]
elif max_memory == 30:
target_regions = [(144, 154)]
else:
raise NotImplementedError()
else:
raise NotImplementedError()
assert found_regions == target_regions, "found regions %s doesn't equal target regions %s" % (
str(found_regions),
str(target_regions),
)
def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = base_evoformer().cuda()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    # must use symbolic_trace here so that MetaInfoProp can run on the resulting meta graph
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "node": node.to(torch.device("meta")),
            "pair": pair.to(torch.device("meta")),
        },
    )
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
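    # only the search result is inspected here; a tighter max_memory budget (MB) yields more chunk regions, and None chunks most aggressively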
chunk_infos = codegen.chunk_infos
assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)
gpc.destroy()
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0 or simple_evoformer is not installed",
)
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_simple_evoformer_search(msa_len, pair_len, max_memory):
run_func = partial(
_test_simple_evoformer_search,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_simple_evoformer_search(0, 32, 64, 20)
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from transformers import GPT2Config, GPT2Model
MODELS = [GPT2Model]
HAS_REPO = True
except Exception:
MODELS = []
HAS_REPO = False
from test_transformer_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
BATCH_SIZE = 2
SEQ_LENGTH = 256
def get_data(shape: tuple) -> Tuple[Dict, Dict, List]:
input_ids = torch.zeros(shape, dtype=torch.int64)
token_type_ids = torch.zeros(shape, dtype=torch.int64)
attention_mask = torch.ones(shape, dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
concrete_args = {"past_key_values": None}
sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"]
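    # the sequence above fixes the positional argument order used to feed MetaInfoProp and the compiled GraphModule (see test_transformer_utils)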
return meta_args, concrete_args, sequence
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0 or transformers is not installed",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
def test_gpt(model, shape, max_memory):
run_func = partial(
run_test,
data=get_data(shape),
max_memory=max_memory,
model=model,
        config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data=get_data((BATCH_SIZE, SEQ_LENGTH)),
max_memory=None,
model=GPT2Model,
        config=GPT2Config(n_embd=96, n_positions=SEQ_LENGTH, n_layer=2, n_head=4),
print_code=True,
print_mem=True,
print_progress=False,
)
from typing import Any, Dict, List
import torch
import torch.fx
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def assert_codegen_run(
model: Any,
data: tuple,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
meta_args, concrete_args, sequence = data
if concrete_args is None:
concrete_args = {}
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
concrete_args={k: v for k, v in concrete_args.items()},
)
interp = MetaInfoProp(meta_graph)
meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
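    # non-tensor entries (such as a None past_key_values) are passed through unchanged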
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model.cuda(),
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
concrete_args={k: v for k, v in concrete_args.items()},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert chunk in code
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
# assert result
inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
model.cuda().eval()
gm.eval()
with torch.no_grad():
out_gm = gm(*inputs)
out_model = model(*inputs)
for k in out_model.keys():
if torch.is_tensor(out_gm[k]):
assert torch.equal(
out_model[k], out_gm[k]
), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
return chunks
def run_test(
rank: int,
model: Any,
config: Any,
data: tuple,
max_memory: int,
print_code: bool,
print_mem: bool,
print_progress: bool,
get_chunk_target: Any = None,
) -> None:
model = model(config=config)
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
chunks = assert_codegen_run(
model,
data=data,
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_progress=print_progress,
)
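    # optional regression check: get_chunk_target maps each max_memory value to the chunk regions expected from the search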
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]
chunk_target = get_chunk_target()[max_memory]
assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
str(chunk_found),
str(chunk_target),
)
gpc.destroy()
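def _autochunk_usage_sketch(model: torch.nn.Module, sample_x: torch.Tensor):
    # Illustrative sketch only (not used by the tests): the same pipeline as assert_codegen_run,
    # condensed for a model whose forward takes a single tensor. The argument name "x" and the
    # 10 MB budget are placeholder assumptions; adapt them to the real model signature.
    meta_graph = symbolic_trace(model, meta_args={"x": sample_x.to(torch.device("meta"))})
    MetaInfoProp(meta_graph).propagate(MetaTensor(sample_x, fake_device="cuda:0"))
    codegen = AutoChunkCodeGen(meta_graph, max_memory=10)
    # re-trace with ColoTracer so the chunked codegen can be attached, then recompile
    graph = ColoTracer().trace(model.cuda(), meta_args={"x": sample_x.to(torch.device("meta"))})
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()
    return gm(sample_x.cuda())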