Commit 8208fd02 authored by jiaruifang

Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into dev0116

parents 438ea608 d565a248
@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                     runtime_apply,
                     args=(node, origin_dict_node, input_dict_node,
                           node_to_index_dict[node], user_node_index))
+            if 'activation_checkpoint' in user_node.meta:
+                shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']

             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                 # substitute the origin node with comm_spec_apply_node
                 new_kwargs[str(node)] = comm_spec_apply_node
                 user.kwargs = new_kwargs
+        if 'activation_checkpoint' in node.meta:
+            comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+    return gm
+
+
+def _act_annotataion_pass(gm: torch.fx.GraphModule):
+    """
+    This pass is used to add the act annotation to the new inserted nodes.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    for node in nodes:
+        if not hasattr(node.meta, 'activation_checkpoint'):
+            from .runtime_preparation_pass import size_processing
+
+            user_act_annotation = -1
+            input_act_annotation = -1
+            for user_node in node.users.keys():
+                if 'activation_checkpoint' in user_node.meta:
+                    user_act_annotation = user_node.meta['activation_checkpoint']
+                    break
+            for input_node in node._input_nodes.keys():
+                if 'activation_checkpoint' in input_node.meta:
+                    input_act_annotation = input_node.meta['activation_checkpoint']
+                    break
+            if user_act_annotation == input_act_annotation and user_act_annotation != -1:
+                node.meta['activation_checkpoint'] = user_act_annotation
+
     return gm
...
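Note: the hunks above, together with _act_annotataion_pass, propagate the 'activation_checkpoint' region id from neighboring nodes onto nodes the passes insert. As a minimal, self-contained sketch of the same inference rule (not the ColossalAI implementation; only public torch.fx APIs are used):

import torch.fx

def annotate_inserted_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in tuple(gm.graph.nodes):
        if 'activation_checkpoint' in node.meta:
            continue
        # take the first annotated user and the first annotated input, -1 if none
        user_ckpt = next((u.meta['activation_checkpoint']
                          for u in node.users if 'activation_checkpoint' in u.meta), -1)
        input_ckpt = next((i.meta['activation_checkpoint']
                           for i in node.all_input_nodes if 'activation_checkpoint' in i.meta), -1)
        # annotate only when both neighbours agree on the same ckpt region
        if user_ckpt == input_ckpt and user_ckpt != -1:
            node.meta['activation_checkpoint'] = user_ckpt
    return gm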
@@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # It will be used to replace the original node with processing node in slice object
             node_pairs[node] = size_processing_node
             size_processing_node._meta_data = node._meta_data
+            if 'activation_checkpoint' in node.meta:
+                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']

             user_list = list(node.users.keys())
             for user in user_list:
...
@@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
     into the forward function.
     '''

-    def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
+    def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
                  origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
         '''
         Args:
@@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
     return strategies_constructor


-def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
+def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
     '''
     This method is used to solve the best solution for the given graph.
     The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
@@ -97,7 +98,7 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor,
     return solution


-def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
+def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
                                strategies_constructor: StrategiesConstructor):
     '''
     This method is used to transform the original graph to the sharded graph.
@@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
        solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
        return a series of integers, but return the best strategies.
     '''
-    tracer = ColoTracer()
+    tracer = ColoTracer(trace_act_ckpt=True)
     graph = tracer.trace(root=model, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
     strategies_constructor = build_strategy_constructor(graph, device_mesh)

     if load_solver_solution:
...
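Note: pieced together from the hunks above, the tracing entry point now records activation-checkpoint regions and wraps the traced graph in ColoGraphModule instead of a plain fx.GraphModule. A condensed sketch of that front half of initialize_model, using only the calls shown in this diff:

from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer

def trace_model(model, meta_args):
    tracer = ColoTracer(trace_act_ckpt=True)    # keep ckpt annotations in node.meta
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    return gm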
@@ -48,9 +48,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
     return new_shape


-def _gen_loop_start(
-    chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2
-) -> str:
+def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
     """
     Generate chunk loop start
@@ -72,9 +70,8 @@ def _gen_loop_start(
     out_shape = get_node_shape(chunk_output)
     out_str = str(list(out_shape))
     context = (
-        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
-        % (out_str, input_node.name, input_node.name, chunk_size)
-    )
+        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
+        (out_str, input_node.name, input_node.name, chunk_size))
     context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
     return context
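Note: to make the emitted source concrete, suppose the chunked output has shape [64, 128], the flattened input node is named x, chunk_ouput_dim is 0 and chunk_size is 2 (hypothetical values, not taken from a real trace). _gen_loop_start would then return source roughly like:

chunk_result = torch.empty([64, 128], dtype=x.dtype, device=x.device); chunk_size = 2
for chunk_idx in range(0, 64, chunk_size):
    ...    # the chunked region body is emitted inside this loop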
@@ -105,26 +102,17 @@ def _gen_loop_end(
     chunk_outputs_name = chunk_outputs.name
     chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
     chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
-    chunk_slice = _gen_chunk_slice_dim(
-        chunk_outputs_dim, "chunk_idx", chunk_output_shape
-    )
+    chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
     context = "    chunk_result%s = %s; %s = None\n" % (
         chunk_slice,
         chunk_outputs_name,
         chunk_outputs_name,
     )
-    context += (
-        chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
-    )
+    context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")

     # determine if its the last use for chunk input
     for chunk_input in chunk_inputs + chunk_non_compute_inputs:
-        if all(
-            [
-                find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
-                for user in chunk_input.users.keys()
-            ]
-        ):
+        if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
             context += "; %s = None" % chunk_input.name

     context += "\n"
@@ -171,17 +159,10 @@ def _replace_ones_like(
             chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
             if get_node_shape(meta_node)[chunk_dim] != 1:
                 source_node = meta_node.args[0].args[0]
-                if (
-                    source_node not in chunk_infos[region_idx]["node_chunk_dim"]
-                    or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"]
-                    is None
-                ):
-                    chunk_slice = _gen_chunk_slice_dim(
-                        chunk_dim, "chunk_idx", get_node_shape(node)
-                    )
-                    body[-1] = _replace_name(
-                        body[-1], node.args[0].name, node.args[0].name + chunk_slice
-                    )
+                if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+                        or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
+                    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
+                    body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
     return body
@@ -198,12 +179,8 @@ def _replace_input_node(
     for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
         for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
             if idx == node_idx:
-                chunk_slice = _gen_chunk_slice_dim(
-                    dim[0], "chunk_idx", get_node_shape(input_node)
-                )
-                body[-1] = _replace_name(
-                    body[-1], input_node.name, input_node.name + chunk_slice
-                )
+                chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
+                body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
     return body
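Note: _replace_input_node and _replace_ones_like both rewrite the just-emitted source line by appending a chunk slice to a variable name. A toy version of that rewrite, assuming _replace_name performs plain name substitution; the line, the names, and the exact slice text (which comes from _gen_chunk_slice_dim, collapsed above) are hypothetical:

line = "linear_1 = torch.nn.functional.linear(hidden, weight)"
chunk_slice = "[chunk_idx:chunk_idx + chunk_size]"
line = line.replace("hidden", "hidden" + chunk_slice)
# -> linear_1 = torch.nn.functional.linear(hidden[chunk_idx:chunk_idx + chunk_size], weight)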
@@ -237,13 +214,9 @@ def emit_code_with_chunk(
     # chunk inputs
     chunk_inputs = [i["inputs"] for i in chunk_infos]    # input with chunk
-    chunk_inputs_non_chunk = [
-        i["inputs_non_chunk"] for i in chunk_infos
-    ]    # input without chunk
+    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]    # input without chunk
     chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]    # input chunk dim
-    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
-        j.name for i in chunk_inputs_non_chunk for j in i
-    ]
+    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]

     # chunk outputs
     chunk_outputs = [i["outputs"][0] for i in chunk_infos]
@@ -267,23 +240,16 @@ def emit_code_with_chunk(
                     chunk_outputs[region_idx],
                     chunk_outputs_dim[region_idx],
                     chunk_infos[region_idx]["chunk_size"],
-                )
-            )
+                ))

         if within_chunk_region:
             emit_node_func(node, body)
             # replace input var with chunk var
-            body = _replace_input_node(
-                chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
-            )
+            body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
             # ones like
-            body = _replace_ones_like(
-                search_chunk, chunk_infos, region_idx, node_idx, node, body
-            )
+            body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
             # reassgin reshape size
-            body[-1] = _replace_reshape_size(
-                body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
-            )
+            body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
             body[-1] = "    " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
         else:
@@ -300,8 +266,7 @@ def emit_code_with_chunk(
                     chunk_outputs[region_idx],
                     chunk_outputs_dim[region_idx],
                     node_list,
-                )
-            )
+                ))
             within_chunk_region = False

         node_idx += 1
@@ -310,18 +275,14 @@ def emit_code_with_chunk(
 if CODEGEN_AVAILABLE:

     class AutoChunkCodeGen(CodeGen):

         def __init__(self, meta_graph, max_memory=None, print_mem=False):
             super().__init__()
-            self.meta_graph = meta_graph
-            self.max_memory = max_memory
-            self.meta_node = list(meta_graph.graph.nodes)
             # find the chunk regions
             self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
             self.chunk_infos = self.search_chunk.search_region()

-        def _gen_python_code(
-            self, nodes, root_module: str, namespace: _Namespace
-        ) -> PythonCode:
+        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
             free_vars: List[str] = []
             body: List[str] = []
             globals_: Dict[str, Any] = {}
@@ -338,9 +299,7 @@ if CODEGEN_AVAILABLE:
                 Returns: the global name that should be used to reference 'obj' in generated source.
                 """
-                if (
-                    _is_from_torch(obj) and obj != torch.device
-                ):    # to support registering torch.device
+                if (_is_from_torch(obj) and obj != torch.device):    # to support registering torch.device
                     # HACK: workaround for how torch custom ops are registered. We
                     # can't import them like normal modules so they must retain their
                     # fully qualified name.
@@ -356,9 +315,7 @@ if CODEGEN_AVAILABLE:
                 return global_name

             # set _custom_builtins here so that we needn't import colossalai in forward
-            _custom_builtins["colossalai"] = _CustomBuiltin(
-                "import colossalai", colossalai
-            )
+            _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

             # Pre-fill the globals table with registered builtins.
             for name, (_, obj) in _custom_builtins.items():
@@ -394,9 +351,8 @@ if CODEGEN_AVAILABLE:
                 # Common case: this is a regular module name like 'foo.bar.baz'
                 return add_global(typename, o)

-            def _format_args(
-                args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
-            ) -> str:
+            def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
                 def _get_repr(arg):
                     # Handle NamedTuples (if it has `_fields`) via add_global.
                     if isinstance(arg, tuple) and hasattr(arg, "_fields"):
@@ -444,26 +400,18 @@ if CODEGEN_AVAILABLE:
                 nodes_to_delete = user_to_last_uses.get(user, [])
                 nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
                 if len(nodes_to_delete):
-                    to_delete_str = " = ".join(
-                        [repr(n) for n in nodes_to_delete] + ["None"]
-                    )
+                    to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
                     body.append(f"; {to_delete_str}\n")
                 else:
                     body.append("\n")

             # NOTE: we add a variable to distinguish body and ckpt_func
             def emit_node(node: Node, body):
-                maybe_type_annotation = (
-                    "" if node.type is None else f" : {type_repr(node.type)}"
-                )
+                maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
                 if node.op == "placeholder":
                     assert isinstance(node.target, str)
-                    maybe_default_arg = (
-                        "" if not node.args else f" = {repr(node.args[0])}"
-                    )
-                    free_vars.append(
-                        f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
-                    )
+                    maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
+                    free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
                     raw_name = node.target.replace("*", "")
                     if raw_name != repr(node):
                         body.append(f"{repr(node)} = {raw_name}\n")
@@ -472,68 +420,46 @@ if CODEGEN_AVAILABLE:
                     assert isinstance(node.target, str)
                     body.append(
                         f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
-                        f"({_format_args(node.args[1:], node.kwargs)})"
-                    )
+                        f"({_format_args(node.args[1:], node.kwargs)})")
                     return
                 elif node.op == "call_function":
                     assert callable(node.target)
                     # pretty print operators
-                    if (
-                        node.target.__module__ == "_operator"
-                        and node.target.__name__ in magic_methods
-                    ):
+                    if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
                         assert isinstance(node.args, tuple)
-                        body.append(
-                            f"{repr(node)}{maybe_type_annotation} = "
-                            f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
-                        )
+                        body.append(f"{repr(node)}{maybe_type_annotation} = "
+                                    f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
                         return

                     # pretty print inplace operators; required for jit.script to work properly
                     # not currently supported in normal FX graphs, but generated by torchdynamo
-                    if (
-                        node.target.__module__ == "_operator"
-                        and node.target.__name__ in inplace_methods
-                    ):
-                        body.append(
-                            f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
-                            f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
-                        )
+                    if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
+                        body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+                                    f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
                         return

                     qualified_name = _get_qualified_name(node.target)
                     global_name = add_global(qualified_name, node.target)
                     # special case for getattr: node.args could be 2-argument or 3-argument
                     # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
-                    if (
-                        global_name == "getattr"
-                        and isinstance(node.args, tuple)
-                        and isinstance(node.args[1], str)
-                        and node.args[1].isidentifier()
-                        and len(node.args) == 2
-                    ):
+                    if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
+                            and node.args[1].isidentifier() and len(node.args) == 2):
                         body.append(
-                            f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
-                        )
+                            f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
                         return
                     body.append(
-                        f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
-                    )
+                        f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
                     if node.meta.get("is_wrapped", False):
                         wrapped_fns.setdefault(global_name)
                     return
                 elif node.op == "call_module":
                     assert isinstance(node.target, str)
-                    body.append(
-                        f"{repr(node)}{maybe_type_annotation} = "
-                        f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
-                    )
+                    body.append(f"{repr(node)}{maybe_type_annotation} = "
+                                f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
                     return
                 elif node.op == "get_attr":
                     assert isinstance(node.target, str)
-                    body.append(
-                        f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}"
-                    )
+                    body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
                     return
                 elif node.op == "output":
                     if node.type is not None:
@@ -564,9 +490,7 @@ if CODEGEN_AVAILABLE:
             if len(wrapped_fns) > 0:
                 wrap_name = add_global("wrap", torch.fx.wrap)
-                wrap_stmts = "\n".join(
-                    [f'{wrap_name}("{name}")' for name in wrapped_fns]
-                )
+                wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
             else:
                 wrap_stmts = ""
...
@@ -10,6 +10,7 @@ from .utils import (


 class TraceFlow(object):
+
     def __init__(self, trace_indice: TraceIndice) -> None:
         self.trace_indice = trace_indice
@@ -28,9 +29,7 @@ class TraceFlow(object):
         start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
         end_node_trace = self.trace_indice._find_trace_from_node(end_node)
         end_node_trace_source = end_node_trace["source"][end_dim]
-        sorted_source = sorted(
-            end_node_trace_source.items(), key=lambda d: d[0], reverse=True
-        )
+        sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
         for node_idx, node_dim in sorted_source:
             if node_idx == start_node_idx and start_dim in node_dim:
                 return True
@@ -70,10 +69,8 @@ class TraceFlow(object):
         input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
         node_trace_source = self.trace_indice._find_source_trace_from_node(node)
         for node_dim in range(len(get_node_shape(node))):
-            if (
-                input_node_idx in node_trace_source[node_dim]
-                and input_dim[0] in node_trace_source[node_dim][input_node_idx]
-            ):
+            if (input_node_idx in node_trace_source[node_dim]
+                    and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
                 return node_dim
         return None
@@ -81,15 +78,11 @@ class TraceFlow(object):
         input_dim_after_node = {}
         for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
             for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
-                inherit_dim = self._find_inherit_dim(
-                    input_node, v, self.trace_indice.node_list[k]
-                )
+                inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
                 if inherit_dim:
                     input_dim_after_node[k] = inherit_dim

-        for node in self.trace_indice.node_list[
-            chunk_infos["region"][0] : chunk_infos["region"][1] + 1
-        ]:
+        for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
             if is_non_compute_node_except_placeholder(node):
                 continue
             count = 0
@@ -159,9 +152,7 @@ class TraceFlow(object):
                 if arg_node in all_node_info:
                     if all_node_info[arg_node]["chunk_dim"] != arg_dim:
                         return False
-                    all_node_info[arg_node]["fix_dim"] = list(
-                        set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
-                    )
+                    all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
                 # else add it to list
                 else:
                     all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
@@ -170,9 +161,7 @@ class TraceFlow(object):
         return True

     def _get_all_node_info(self, end_dim, start_idx, end_idx):
-        cur_node_list = [
-            self.trace_indice.node_list[end_idx]
-        ]    # start from the last node
+        cur_node_list = [self.trace_indice.node_list[end_idx]]    # start from the last node
         all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}

         while len(cur_node_list) > 0:
@@ -183,12 +172,8 @@ class TraceFlow(object):
             cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
             cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
             if cur_node_chunk_dim:
-                cur_node_compute = self.trace_indice._find_compute_trace_from_node(
-                    cur_node
-                )
-                cur_node_source = self.trace_indice._find_source_trace_from_node(
-                    cur_node
-                )
+                cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
+                cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
             else:
                 cur_node_compute = cur_node_source = None
@@ -215,15 +200,9 @@ class TraceFlow(object):
                     return None

             if len(arg_list) == 2:
-                if any(i in cur_node.name for i in ["add", "mul"]):
+                if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
                     for arg in arg_list:
-                        if not (
-                            start_idx
-                            <= find_idx_by_name(
-                                arg.name, self.trace_indice.node_list
-                            )
-                            < end_idx
-                        ):
+                        if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
                             continue
                         arg_chunk_dim = all_node_info[arg]["chunk_dim"]
                         arg_fix_dim = all_node_info[arg]["fix_dim"]
@@ -249,9 +228,7 @@ class TraceFlow(object):
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
-            input_node_idx = find_idx_by_name(
-                input_node.name, self.trace_indice.node_list
-            )
+            input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
             for user in input_node.users.keys():
                 if is_non_compute_node(user):
                     continue
@@ -259,9 +236,7 @@ class TraceFlow(object):
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
                     if chunk_dim is not None:
-                        user_source = self.trace_indice._find_source_trace_from_node(
-                            user
-                        )[chunk_dim]
+                        user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
                         if input_node_idx in user_source:
                             input_dict[user_idx] = user_source[input_node_idx]
                         else:
@@ -305,13 +280,8 @@ class TraceFlow(object):
                     if type(cur_prepose_node_arg) != type(cur_prepose_node):
                         continue
                     # out of loop
-                    if not (
-                        start_idx
-                        <= find_idx_by_name(
-                            cur_prepose_node_arg.name, self.trace_indice.node_list
-                        )
-                        < end_idx
-                    ):
+                    if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
+                            end_idx):
                         continue
                     # compute op in loop
                     elif cur_prepose_node_arg in all_node_info:
@@ -335,15 +305,13 @@ class TraceFlow(object):
                 if n in maybe_prepose_nodes:
                     maybe_prepose_nodes.remove(n)
         # sort by index
-        prepose_nodes.sort(
-            key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
-        )
+        prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))

         return prepose_nodes

     def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
         # we need to log input nodes to avoid deleteing them in the loop
-        chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
+        chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]

         # also need to get some prepose node's arg out of non_chunk_inputs
         for n in chunk_info["args"]["prepose_nodes"]:
             chunk_node_list.remove(n)
@@ -354,9 +322,7 @@ class TraceFlow(object):
         return chunk_info

     def flow_search(self, start_idx, start_dim, end_idx, end_dim):
-        inputs, outputs = find_chunk_compute_input_and_output_nodes(
-            self.trace_indice.node_list[start_idx : end_idx + 1]
-        )
+        inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
         # only single ouput
         if len(outputs) > 1:
             return None
@@ -367,9 +333,7 @@ class TraceFlow(object):
             return None

         # get input nodes' chunk dim
-        inputs, inputs_dim = self._get_input_nodes_dim(
-            inputs, start_idx, end_idx, all_node_info
-        )
+        inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
         if inputs is None:
             return None
@@ -385,9 +349,7 @@ class TraceFlow(object):
         }

         # move useless nodes ahead of loop
-        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
-            all_node_info, start_idx, end_idx
-        )
+        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)

         # find non chunk inputs
         chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
@@ -400,10 +362,8 @@ class TraceFlow(object):
     def _reassgin_reshape_size(self, chunk_info):
         chunk_region = chunk_info["region"]
         reshape_size = {}
-        chunk_shape = get_node_shape(chunk_info["outputs"][0])[
-            chunk_info["outputs_dim"]
-        ]
-        for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
+        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
+        for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
             if any(i in node.name for i in ["reshape", "view"]):
                 reshape_args = node.args[1:]
                 reshape_log = self.trace_indice.indice_view_list[node]
@@ -413,8 +373,6 @@ class TraceFlow(object):
                     if reshape_arg_dim in reshape_log["dim_to"]:
                         continue
                     if reshape_arg_dim == chunk_dim:
-                        reshape_size[node.name][reshape_arg.name] = (
-                            "min(chunk_size, %d - chunk_idx)" % chunk_shape
-                        )
+                        reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple

 from torch.fx.node import Node

-from .utils import find_idx_by_name, get_node_shape
+from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list


 class TraceIndice(object):
@@ -79,9 +79,7 @@ class TraceIndice(object):
         node_from_trace = self._find_trace_from_node(node_from)
         node_to_trace = self._find_trace_from_node(node_to)
         node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
-        node_to_trace["compute"][node_to_dim] = copy.deepcopy(
-            node_from_trace["compute"][node_from_dim]
-        )
+        node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
         self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)

     def _inherit_all_computation(self, node_from, node_to):
@@ -209,7 +207,7 @@ class TraceIndice(object):
             node_idx (int)
         """
         if input_node == None:
-            input_node = node.args[0]
+            input_node = find_first_tensor_arg(node)
         input_node_idx = find_idx_by_name(input_node.name, self.node_list)
         input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
@@ -227,6 +225,8 @@ class TraceIndice(object):
             node_idx (int)
         """
         shape = node.meta["tensor_meta"].shape
+        if shape is None:
+            return
         new_trace = []
         for _ in shape:
             new_trace.append(self._add_indice())
@@ -259,7 +259,7 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        permute_dim = node.args[1:]
+        permute_dim = unflat_list(node.args[1:])
         input_node = node.args[0]

         self._assign_indice_as_input(node, node_idx, input_node)
@@ -359,6 +359,15 @@ class TraceIndice(object):
         left, right = patterns.split("->")
         left = left.split(",")

+        if '...' in right:
+            replace_list = "!@#$%^&*"
+            target_len = len(get_node_shape(node))
+            add_len = target_len - len(right) + 3
+            replace_str = replace_list[:add_len]
+            right = right.replace("...", replace_str)
+            for ll in range(len(left)):
+                left[ll] = left[ll].replace("...", replace_str)
+
         all_index = []
         for i in left:
             for c in i:
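Note: a worked example of the new ellipsis branch, for the (hypothetical) output pattern "...ji" on a rank-4 einsum output: '...' must stand for the two batch dims, so it is replaced by two placeholder characters drawn from "!@#$%^&*"; the left-hand terms get the same substitution.

right = "...ji"
target_len = 4                           # rank of the einsum output node
add_len = target_len - len(right) + 3    # 4 - 5 + 3 == 2
replace_str = "!@#$%^&*"[:add_len]       # "!@"
right = right.replace("...", replace_str)
assert right == "!@ji"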
@@ -369,9 +378,7 @@ class TraceIndice(object):
         for left_idx, left_str in enumerate(left):
             if right_indice in left_str:
                 source_idx = left_str.index(right_indice)
-                self._inherit_indice(
-                    input_nodes[left_idx], source_idx, node, right_idx
-                )
+                self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)

     def _assign_softmax_indice(self, node, idx):
         """
@@ -440,11 +447,12 @@ class TraceIndice(object):
         origin_node = node.args[0]
         origin_shape = origin_node.meta["tensor_meta"].shape
         target_shape = []
-        for i in range(1, len(node.args)):
-            if isinstance(node.args[i], int):
-                target_shape.append(node.args[i])
+        unflated_args = unflat_list(node.args)
+        for i in range(1, len(unflated_args)):
+            if isinstance(unflated_args[i], int):
+                target_shape.append(unflated_args[i])
             else:
-                target_shape.append(node.args[i].meta["fwd_out"][0])
+                target_shape.append(unflated_args[i].meta["fwd_out"][0])

         # compute the value of -1
         if -1 in target_shape:
@@ -472,13 +480,7 @@ class TraceIndice(object):
             dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
             self._del_dim(node_idx, -1)
         else:
-            raise NotImplementedError(
-                "shape"
-                + str(origin_shape)
-                + "and"
-                + str(target_shape)
-                + "view not implemented"
-            )
+            raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")

         # get new indice
         origin_trace = self._find_indice_trace_from_node(origin_node)
@@ -521,6 +523,8 @@ class TraceIndice(object):
                 self._assign_unsqueeze_indice(node, idx)
             elif any(i in node.name for i in ["to", "contiguous"]):
                 self._assgin_no_change_indice(node, idx)
+            elif "new_ones" in node.name:
+                self._assign_ones_like_indice(node, idx)
             else:
                 raise NotImplementedError(node.name, "method not implemented yet!")
         elif node.op == "call_function":
@@ -530,7 +534,7 @@ class TraceIndice(object):
                 self._assign_matmul_indice(node, idx)
             elif "softmax" in node.name:
                 self._assign_softmax_indice(node, idx)
-            elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]):
+            elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
                 self._assign_elementwise_indice(node, idx)
             elif "ones_like" in node.name:
                 self._assign_ones_like_indice(node, idx)
@@ -538,17 +542,17 @@ class TraceIndice(object):
                 self._assign_dropout_indice(node, idx)
             elif "einsum" in node.name:
                 self._assign_einsum_indice(node, idx)
-            elif "getattr" in node.name:
-                continue    # get attr like shape
-            elif "getitem" in node.name:
-                continue    # get item in list
+            elif "layer_norm" in node.name:
+                self._assign_layernorm_indice(node, idx)
+            elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
+                continue
             else:
-                raise NotImplementedError(
-                    node.name, "function not implemented yet!"
-                )
+                raise NotImplementedError(node.name, "function not implemented yet!")
         elif node.op == "call_module":
             if any(n in node.name for n in ["layernorm", "norm"]):
                 self._assign_layernorm_indice(node, idx)
+            elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
                self._assign_elementwise_indice(node, idx)
             else:
                 raise NotImplementedError(node.name, "module not implemented yet!")
         elif node.op == "get_attr":
...
@@ -3,10 +3,32 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple

 from torch.fx.node import Node


+def unflat_list(inputs):
+    """
+    unflat a list by recursion
+    """
+    res = []
+    for i in inputs:
+        if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
+            res.extend(unflat_list(i))
+        else:
+            res.append(i)
+    return res
+
+
+def find_first_tensor_arg(node):
+    """
+    Find the first input tensor arg for a node
+    """
+    for arg in node.args:
+        if type(arg) == type(node):
+            return arg
+    raise RuntimeError()
+
+
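Note: despite its name, unflat_list flattens nested argument containers, and find_first_tensor_arg returns the first fx.Node among node.args (tensor arguments of an fx node appear as Node objects, hence the type(node) comparison). A quick sanity check of the first helper:

assert unflat_list([1, (2, [3, 4]), 5]) == [1, 2, 3, 4, 5]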
 def is_non_compute_node(node):
     if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
+            i in node.name for i in ["getitem", "getattr"]):
         return True
     return False
@@ -18,17 +40,13 @@ def get_node_shape(node):

 def is_non_compute_node_except_placeholder(node):
-    if any(i in node.op for i in ["get_attr", "output"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
+    if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
         return True
     return False


 def is_non_compute_node_except_placeholder_output(node):
-    if any(i in node.op for i in ["get_attr"]) or any(
-        i in node.name for i in ["getitem", "getattr"]
-    ):
+    if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
         return True
     return False
@@ -74,22 +92,16 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
     # we treat that input node as the input of the checkpoint function
     for node in nodes:
         for input_node in node._input_nodes.keys():
-            if (
-                input_node not in nodes
-                and input_node not in input_nodes
-                and not is_non_compute_node_except_placeholder(input_node)
-            ):
+            if (input_node not in nodes and input_node not in input_nodes
                    and not is_non_compute_node_except_placeholder(input_node)):
                 input_nodes.append(input_node)

     # if a node has a user node which is not in the node list
     # we treat that user node as the node receiving the current node output
     for node in nodes:
         for output_node in node.users.keys():
-            if (
-                output_node not in nodes
-                and node not in output_nodes
-                and not is_non_compute_node_except_placeholder_output(output_node)
-            ):
+            if (output_node not in nodes and node not in output_nodes
                    and not is_non_compute_node_except_placeholder_output(output_node)):
                 output_nodes.append(node)

     return input_nodes, output_nodes
@@ -249,6 +249,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.sum.default,
         aten.sum.dim_IntList,
         aten.mean.dim,
+        aten.sub.Tensor,
+        aten.sub_.Tensor,

         # activation op
         aten.hardswish.default,
@@ -313,7 +315,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.where.self,
         aten.zero_.default,
         aten.zeros_like.default,
-    ]
+        aten.fill_.Scalar
+    ]    # yapf: disable

     for op in zero_flop_aten:
         flop_mapping[op] = zero_flop_jit
...
@@ -7,7 +7,6 @@ class BucketStore(BaseStore):

     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
-        self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()
@@ -19,25 +18,24 @@ class BucketStore(BaseStore):
     def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
         self._num_elements_in_bucket[reduce_rank] += num_elements

-    def add_grad(self, tensor, reduce_rank: int = None):
-        self._grads[reduce_rank].append(tensor)
-
     def add_param(self, tensor, reduce_rank: int = None):
         self._params[reduce_rank].append(tensor)

     def reset(self):
         keys = [None] + list(range(self._world_size))
-        self._grads = {rank: [] for rank in keys}
         self._params = {rank: [] for rank in keys}
         self._num_elements_in_bucket = {rank: 0 for rank in keys}

     def reset_by_rank(self, reduce_rank=None):
-        self._grads[reduce_rank] = []
         self._params[reduce_rank] = []
         self._num_elements_in_bucket[reduce_rank] = 0

     def get_grad(self, reduce_rank: int = None):
-        return self._grads[reduce_rank]
+        param_list = self.get_param(reduce_rank)
+        for param in param_list:
+            # the param must have grad for reduction
+            assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
+        return [param.grad for param in param_list]

     def get_param(self, reduce_rank: int = None):
         return self._params[reduce_rank]
...@@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): ...@@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
reduce_bucket_size: int = 1024 * 1024, # communication reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None, communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False, overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None): forced_dtype: Optional[torch.dtype] = None):
...@@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): ...@@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0]) self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank return params_per_rank
########################################################### ###########################
# Backward Reduction Hook # Backward Reduction Hook #
########################################################### ###########################
def _grad_handler(self, param, grad, reduce_rank):
self._add_to_reduction_bucket(param, reduce_rank)
return grad
def _attach_reduction_hook(self): def _attach_reduction_hook(self):
# we iterate over the fp16 params # we iterate over the fp16 params
...@@ -268,51 +272,59 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): ...@@ -268,51 +272,59 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
else: else:
reduce_rank = None reduce_rank = None
def _define_and_attach(param, reduce_rank): param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
# get the AccumulateGrad object of the param itself
accum_grad_obj = get_grad_accumulate_object(param)
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
reduction_func = partial(self._reduce_and_remove_grads_by_bucket, def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
param=param, if self._overlap_communication:
reduce_rank=reduce_rank) torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
# define hook with torch.cuda.stream(stream):
# NOT IMPORTANT BUT GOOD TO KNOW: flat = bucket.flatten()
# args here is not grad, but allow_unreacable and accumulate_grad reduce_global_rank = None
def reduce_grad_hook(*args): if reduce_rank is not None:
reduction_func() reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)
accum_grad_obj.register_hook(reduce_grad_hook) # update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
_define_and_attach(param, reduce_rank) def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None): for tensor in tensor_list:
param_size = param.numel() param_bucket.add_to_bucket(tensor, allow_oversize=True)
# check if the bucket is full if param_bucket.is_full_or_oversized():
# if full, will reduce the grads already in the bucket self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
-        # after reduction, the bucket will be empty
-        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_in_bucket(reduce_rank)
-
-        # the param must not be reduced to ensure correctness
-        is_param_reduced = self._param_store.is_param_reduced(param)
-        if is_param_reduced:
-            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
-                + 'duplicate reduction will lead to arithmetic incorrectness'
-            raise RuntimeError(msg)
-
-        # the param must have grad for reduction
-        assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
-
-        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
-        self._bucket_store.add_grad(param.grad, reduce_rank)
-        self._bucket_store.add_param(param, reduce_rank)
-
-    def _reduce_grads_in_bucket(self, reduce_rank=None):
+                    param_bucket.empty()
+
+        if not param_bucket.is_empty():
+            self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
+
+    def _reduce_grads(self, reduce_rank, grads, bucket_size):
+        grad_buckets_by_dtype = split_half_float_double(grads)
+
+        for tensor_list in grad_buckets_by_dtype:
+            self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
+                                                    bucket_size=bucket_size,
+                                                    reduce_rank=reduce_rank)
+
+    #######################
+    # Reduction Functions #
+    #######################
+
+    def _run_reduction(self, reduce_rank=None):
         # reduce grads
-        self._reduce_grads_by_rank(reduce_rank=reduce_rank,
-                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
-                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
+        self._reduce_grads(reduce_rank=reduce_rank,
+                           grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
+                           bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
@@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         self._bucket_store.reset_by_rank(reduce_rank)

-    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
-        grad_buckets_by_dtype = split_half_float_double(grads)
-
-        for tensor_list in grad_buckets_by_dtype:
-            self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
-
-    ##############################
-    # Reduction Utility Function #
-    ##############################
-    def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
-        param_bucket = TensorBucket(size=bucket_size)
-
-        for tensor in tensor_list:
-            param_bucket.add_to_bucket(tensor, allow_oversize=True)
-
-            if param_bucket.is_full_or_oversized():
-                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
-                param_bucket.empty()
-
-        if not param_bucket.is_empty():
-            self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
-
-    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_communication:
-            torch.cuda.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
-            stream = self._comm_stream
-        else:
-            stream = torch.cuda.current_stream()
-
-        with torch.cuda.stream(stream):
-            flat = bucket.flatten()
-            reduce_global_rank = None
-            if reduce_rank is not None:
-                reduce_global_rank = self._dp_global_ranks[reduce_rank]
-            reduced_flat = reduce_tensor_dp_group(tensor=flat,
-                                                  dtype=self._communication_dtype,
-                                                  dst_local_rank=reduce_rank,
-                                                  dst_global_rank=reduce_global_rank,
-                                                  group=self._dp_torch_group)
-
-            # update the reduced tensor
-            if reduce_rank is None or reduce_rank == self._local_rank:
-                bucket.unflatten_and_copy(reduced_flat)
+    def _add_to_reduction_bucket(self, param, reduce_rank=None):
+        param_size = param.numel()
+
+        # check if the bucket is full
+        # if full, will reduce the grads already in the bucket
+        # after reduction, the bucket will be empty
+        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+            self._run_reduction(reduce_rank)
+
+        # the param must not be reduced to ensure correctness
+        is_param_reduced = self._param_store.is_param_reduced(param)
+        if is_param_reduced:
+            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+                + 'duplicate reduction will lead to arithmetic incorrectness'
+            raise RuntimeError(msg)
+
+        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
+        self._bucket_store.add_param(param, reduce_rank)

     ################################
     # torch.optim.Optimizer methods
@@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # broadcast the updated model weights
         handles = []
         for group_id in range(self.num_param_groups):
-            for rank in range(self._world_size):
-                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
+            for index in range(self._world_size):
+                rank = self._dp_global_ranks[index]
+                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
                 handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
                 handles.append(handle)
@@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
                 if param.grad is not None:
-                    self._reduce_and_remove_grads_by_bucket(param)
+                    self._add_to_reduction_bucket(param)

         # we need to reduce the gradients
         # left in the communication bucket
-        self._reduce_grads_in_bucket()
+        self._run_reduction()

     def _reduce_grad_stage2(self):
         # when partition_grads is True, reduction hooks
@@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # only need to reduce the gradients
         # left in the communication bucket
         for reduce_rank in range(self._world_size):
-            self._reduce_grads_in_bucket(reduce_rank)
+            self._run_reduction(reduce_rank)
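
For context on the bucketing scheme being refactored above: gradients are appended to a per-rank bucket and flushed with a single collective as soon as the next tensor would overflow the size threshold. Below is a minimal standalone sketch of that idea — simplified, with a hypothetical flush_fn callback standing in for the real dist.reduce over the flattened bucket; it is not ColossalAI's API.

import torch

class GradBucket:
    # Simplified sketch: accumulate tensors, flush them in one batch once full.
    # The real optimizer keeps per-rank buckets and reduces the flattened data.

    def __init__(self, max_elements, flush_fn):
        self.max_elements = max_elements
        self.flush_fn = flush_fn    # hypothetical callback, e.g. an all-reduce
        self.tensors = []
        self.num_elements = 0

    def add(self, grad: torch.Tensor):
        # flush first if this tensor would overflow the bucket
        if self.num_elements + grad.numel() > self.max_elements:
            self.flush()
        self.tensors.append(grad)
        self.num_elements += grad.numel()

    def flush(self):
        if not self.tensors:
            return
        flat = torch.cat([t.reshape(-1) for t in self.tensors])
        self.flush_fn(flat)
        self.tensors, self.num_elements = [], 0

bucket = GradBucket(max_elements=2**20, flush_fn=lambda flat: print(flat.numel()))
for g in (torch.randn(1000), torch.randn(2**20)):
    bucket.add(g)
bucket.flush()    # drain whatever is left, mirroring _run_reduction()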
colossalai >= 0.1.12
torch >= 1.8.1

 transformers >= 4.23
-colossalai
+colossalai >= 0.1.12
 torch >= 1.8.1
from functools import partial
from typing import Optional, Tuple, Union
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
HIDDEN_SIZE = 16
class GPT2MLPWithCkpt(nn.Module):
def __init__(self, intermediate_size, hidden_size):
super().__init__()
embed_dim = hidden_size
self.c_fc = Conv1D(intermediate_size, embed_dim)
self.c_proj = Conv1D(embed_dim, intermediate_size)
self.act = torch.nn.ReLU()
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = checkpoint(self.c_proj, hidden_states)
hidden_states = self.act(hidden_states)
return hidden_states
def check_act_ckpt(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
input_sample = {
'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
}
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
gm = initialize_model(model, input_sample, device_mesh)
code = gm.module.graph.python_code('self').src
assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mlp_layer():
world_size = 4
run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_mlp_layer()
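
The string assertions in the test above verify that the checkpointed region survives the auto-parallel passes; the underlying primitive is just torch.utils.checkpoint, which re-runs the wrapped function during backward instead of storing its activations. A minimal, test-independent illustration:

import torch
from torch.utils.checkpoint import checkpoint

linear = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
y = checkpoint(linear, x)    # activations inside `linear` are recomputed in backward
y.sum().backward()
print(x.grad.shape)    # gradients flow through the checkpointed region as usual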
@@ -2,14 +2,13 @@ import time
 import torch
 import torch.fx
+from simple_evoformer import base_evoformer, openfold_evoformer

 from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
 from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.profiler import MetaTensor
-from tests.test_autochunk.evoformer.evoformer import evoformer_base
-from tests.test_autochunk.openfold.evoformer import EvoformerBlock

 def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
@@ -34,10 +33,7 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
     time2 = time.time()
     new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    print(
-        "%s: time %.4fs, mem %dMB"
-        % (title, (time2 - time1) / loop, new_max_mem - now_mem)
-    )
+    print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem))

 def _build_autochunk(model, max_memory, node, pair):
@@ -52,16 +48,12 @@ def _build_autochunk(model, max_memory, node, pair):
     gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
     interp = MetaInfoProp(gm_prop)
-    interp.propagate(
-        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
-    )
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

     # now run it twice to get meta info in graph module, not necessary
     gm = torch.fx.GraphModule(model, graph)
     interp = MetaInfoProp(gm)
-    interp.propagate(
-        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
-    )
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

     # set code_gen
     codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
@@ -75,42 +67,22 @@ def _build_autochunk(model, max_memory, node, pair):
     return gm

-def _build_openfold():
-    model = EvoformerBlock(
-        c_m=256,
-        c_z=128,
-        c_hidden_msa_att=32,
-        c_hidden_opm=32,
-        c_hidden_mul=128,
-        c_hidden_pair_att=32,
-        no_heads_msa=8,
-        no_heads_pair=4,
-        transition_n=4,
-        msa_dropout=0.15,
-        pair_dropout=0.15,
-        inf=1e4,
-        eps=1e-4,
-        is_multimer=False,
-    ).cuda()
-    return model
-
 def benchmark_evoformer():
     # init data and model
-    msa_len = 256
-    pair_len = 512
+    msa_len = 128
+    pair_len = 256
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
-    model = evoformer_base().cuda()
+    model = base_evoformer().cuda()

     # build autochunk model
     # max_memory = 1000    # MB, fit memory mode
     max_memory = None    # min memory mode
-    autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
+    autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair)

     # build openfold
     chunk_size = 64
-    openfold = _build_openfold()
+    openfold = openfold_evoformer().cuda()

     # benchmark
     _benchmark_evoformer(model, node, pair, "base")
...
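
A note on the measurement pattern used by _benchmark_evoformer above: peak memory comes from CUDA's allocator counters, and timing must bracket the loop with synchronize calls, since kernel launches are asynchronous. A generic sketch of the same pattern — measure is a name introduced here for illustration, not part of the benchmark:

import time
import torch

def measure(fn, *args, loop=4):
    # reset the peak counter so max_memory_allocated reflects only this run
    torch.cuda.reset_peak_memory_stats()
    base_mem = torch.cuda.memory_allocated() / 1024**2
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(loop):
        fn(*args)
    torch.cuda.synchronize()    # wait for queued kernels before reading the clock
    elapsed = (time.time() - start) / loop
    peak_mem = torch.cuda.max_memory_allocated() / 1024**2 - base_mem
    print("time %.4fs, peak mem %dMB" % (elapsed, peak_mem))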
import torch
import torch.nn as nn
from .msa import MSAStack
from .ops import OutProductMean
from .triangle import PairStack
def print_memory(init_mem, text=None):
now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
torch.cuda.reset_peak_memory_stats()
class EvoformerBlock(nn.Module):
def __init__(self, d_node, d_pair):
super(EvoformerBlock, self).__init__()
self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=d_pair)
def forward(self, node, pair):
node = self.msa_stack(node, pair)
pair = pair + self.communication(node)
pair = self.pair_stack(pair)
return node, pair
class Evoformer(nn.Module):
def __init__(self, d_node, d_pair):
super(Evoformer, self).__init__()
self.blocks = nn.ModuleList()
for _ in range(1):
self.blocks.append(EvoformerBlock(d_node, d_pair))
def forward(self, node, pair):
for b in self.blocks:
node, pair = b(node, pair)
return node, pair
def evoformer_tiny():
return Evoformer(d_node=64, d_pair=32)
def evoformer_base():
return Evoformer(d_node=256, d_pair=128)
def evoformer_large():
return Evoformer(d_node=512, d_pair=256)
__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
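
For orientation, a shape-level usage sketch of the factories above. Sizes are illustrative; node follows [batch, msa_depth, seq_len, d_node] and pair follows [batch, seq_len, seq_len, d_pair], the convention used by the benchmark script, and the sketch assumes the MSAStack/OutProductMean/PairStack modules imported above behave accordingly.

import torch

model = evoformer_tiny()    # d_node=64, d_pair=32
node = torch.randn(1, 8, 16, 64)     # [batch, msa_depth, seq_len, d_node]
pair = torch.randn(1, 16, 16, 32)    # [batch, seq_len, seq_len, d_pair]
node, pair = model(node, pair)
print(node.shape, pair.shape)        # both shapes are preserved by the block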
import math
import numpy as np
import torch.nn as nn
def glorot_uniform_af(x, gain=1.0):
"""
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
In PyTorch:
[feature_out, feature_in, n_head ...]
In Jax:
[... n_head, feature_in, feature_out]
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
[feature_in, n_head, feature_out]
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
"""
fan_in, fan_out = x.shape[-2:]
if len(x.shape) > 2:
receptive_field_size = np.prod(x.shape[:-2])
fan_in *= receptive_field_size
fan_out *= receptive_field_size
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
nn.init.uniform_(x, -dev, dev)
return x
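
A quick check of the convention described in the docstring: for a [feature_in, n_head, feature_out] tensor, the last two dims give (fan_in, fan_out) and every leading dim multiplies both. A small sketch:

import torch

w = torch.empty(64, 8, 32)    # [feature_in, n_head, feature_out]
glorot_uniform_af(w)
# here fan_in = 8 * 64 = 512 and fan_out = 32 * 64 = 2048, so the uniform
# bound is sqrt(3) * sqrt(2 / (512 + 2048)) ~= 0.048
print(w.abs().max().item())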
import torch
import torch.nn.functional as F
def bias_sigmod_ele(y, bias, z):
return torch.sigmoid(y + bias) * z
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
residual: torch.Tensor, prob: float) -> torch.Tensor:
out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
out = residual + out
return out
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
prob: float) -> torch.Tensor:
return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
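
Note the slightly unusual dropout pattern in these kernels: F.dropout is applied to a pre-built mask rather than to the activations, so the (x + bias) term stays intact for element-wise fusion. A small usage sketch of bias_dropout_add, with shapes mirroring its call site in MSARowAttentionWithPairBias below:

import torch

x = torch.randn(2, 8, 16, 64)
bias = torch.zeros(64)
residual = torch.randn_like(x)
dropmask = torch.ones_like(x[:, 0:1, :, :])    # one mask row, broadcast over dim 1
out = bias_dropout_add(x, bias, dropmask, residual, prob=0.15)
print(out.shape)    # same shape as residual; with training=False the mask is a no-op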
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from .kernel import bias_dropout_add
from .ops import SelfAttention, Transition
class MSARowAttentionWithPairBias(nn.Module):
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
super(MSARowAttentionWithPairBias, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernormM = LayerNorm(d_node)
self.layernormZ = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
def forward(self, M_raw, Z):
## Input projections
M = self.layernormM(M_raw)
Z = self.layernormZ(Z)
b = F.linear(Z, self.linear_b_weights)
b = b.permute(0, 3, 1, 2)
# b = rearrange(b, 'b q k h -> b h q k')
M = self.attention(M, b)
dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
class MSAColumnAttention(nn.Module):
def __init__(self, d_node, c=32, n_head=8):
super(MSAColumnAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True)
def forward(self, M_raw):
M = M_raw.transpose(-2, -3)
M = self.layernormM(M)
M = self.attention(M)
M = M.transpose(-2, -3)
return M_raw + M
class MSAStack(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(MSAStack, self).__init__()
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
d_pair=d_pair,
p_drop=p_drop)
self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
self.MSATransition = Transition(d=d_node)
def forward(self, node, pair):
node = self.MSARowAttentionWithPairBias(node, pair)
node = self.MSAColumnAttention(node)
node = self.MSATransition(node)
return node
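
A shape-level usage sketch for MSAStack, with illustrative sizes; it assumes the SelfAttention and Transition modules imported from .ops preserve the trailing feature dimension, as their use here implies.

import torch

stack = MSAStack(d_node=64, d_pair=32)
node = torch.randn(1, 8, 16, 64)     # [batch, msa_depth, seq_len, d_node]
pair = torch.randn(1, 16, 16, 32)    # [batch, seq_len, seq_len, d_pair]
node = stack(node, pair)
print(node.shape)    # node keeps its shape through row/column attention and transition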