Commit 9e768b59 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 7bc5a8e3 8aed02b9
# Standard-library imports.
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent

# Feature detection: use ColossalAI's activation-checkpoint code generator when
# both colossalai and the private torch.fx internals it relies on are
# importable. COLOGM records which ColoGraphModule implementation this module
# defines (full-featured vs. plain-GraphModule fallback below).
try:
    from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen
    from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall

    from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen

    COLOGM = True
except ImportError:
    # Catch only ImportError (was a bare `except:`): any other failure during
    # import is a real bug and should propagate, not silently downgrade.
    from torch.fx.graph import Graph
    from torch.fx.graph_module import GraphModule

    COLOGM = False
if COLOGM:
class ColoGraphModule(GraphModule):
def __init__(
    self,
    root: Union[torch.nn.Module, Dict[str, Any]],
    graph: Graph,
    class_name: str = "GraphModule",
    ckpt_codegen: bool = True,
):
    """Build a GraphModule, optionally installing ColossalAI's
    activation-checkpoint code generator on ``graph``.

    Args:
        root: Module (or dict of attributes) providing the graph's targets.
        graph: The ``torch.fx`` graph this module will execute.
        class_name: Name for the generated module class.
        ckpt_codegen: When True, replace the graph's codegen with
            ``ActivationCheckpointCodeGen`` before constructing the module.
    """
    if ckpt_codegen:
        # Install before super().__init__ so the forward() source generated
        # during construction goes through the checkpoint-aware codegen.
        graph.set_codegen(ActivationCheckpointCodeGen())
    super().__init__(root, graph, class_name)
......@@ -60,7 +63,7 @@ if COLOGM:
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self')
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
......@@ -83,7 +86,7 @@ if COLOGM:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
if '_wrapped_call' not in vars(cls):
if "_wrapped_call" not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
......@@ -108,7 +111,7 @@ if COLOGM:
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
torch.save(self.state_dict(), folder / 'state_dict.pt')
torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
......@@ -125,7 +128,13 @@ class {module_name}(torch.nn.Module):
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
safe_reprs = [
nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
]
if type(module) in safe_reprs:
return f"{module.__repr__()}"
......@@ -136,10 +145,10 @@ class {module_name}(torch.nn.Module):
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
module_file = folder / f'{module_name}.pt'
module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
......@@ -156,19 +165,20 @@ class {module_name}(torch.nn.Module):
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
module_file = folder / 'module.py'
module_file = folder / "module.py"
module_file.write_text(model_str)
init_file = folder / '__init__.py'
init_file.write_text('from .module import *')
init_file = folder / "__init__.py"
init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
warnings.warn("Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}")
warnings.warn(
"Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}"
)
else:
class ColoGraphModule(GraphModule):
    """Fallback used when ColossalAI's codegen is unavailable: a plain
    torch.fx.GraphModule exposed under the same name so callers are
    unaffected by the missing feature (no ckpt_codegen support)."""

    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"):
        super().__init__(root, graph, class_name)
import numpy as np
import torch
import tqdm
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
......@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
accumulate_bwd_flop = 0
block_nodes = []
for node in gm.graph.nodes:
if 'block_split' in node.name:
if "block_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
accumulate_bwd_flop += node.bwd_flop
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
with gm.graph.inserting_after(node):
block_node = gm.graph.create_node('call_function', block_split)
setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
block_node = gm.graph.create_node("call_function", block_split)
setattr(block_node, "fwd_flop", accumulate_fwd_flop)
setattr(block_node, "bwd_flop", accumulate_bwd_flop)
accumulate_fwd_flop = 0
accumulate_bwd_flop = 0
block_nodes.append(block_node)
......@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
def remove_blocks(gm: torch.fx.GraphModule):
    """Erase every ``block_split`` marker node previously inserted by
    ``construct_blocks``, leaving the real computation nodes untouched.

    Args:
        gm: Graph module whose graph is modified in place.
    """
    # Collect matches first: erasing nodes while iterating gm.graph.nodes
    # mutates the linked node list mid-iteration.
    marker_nodes = [
        node for node in gm.graph.nodes
        if (node.op, node.target) == ("call_function", block_split)
    ]
    for node in marker_nodes:
        gm.graph.erase_node(node)
......@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
num_nodes = len(node_list)
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0):
for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False):
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
all_compute_cost[start, end] = sum(selected_flops)
......@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
# record start node index for next stage in this partition
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
f[0, num_nodes] = 0
for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
for s in tqdm.tqdm(
range(1, num_stages + 1), desc="stage", position=2, leave=False
): # pylint: disable=too-many-nested-blocks
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False):
for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False):
stage_cost = compute_costs[i, k - 1]
new_cost = f[s - 1, k] + stage_cost
if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
if stage_cost <= max_compute_cost and new_cost < f[s, i]:
f[s, i] = new_cost
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
f_argmin[s, i] = k
......@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
if max_compute_cost - last_max_compute_cost < gap:
continue
cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
max_compute_cost)
cost, solution = do_dp_split_gpipe_impl(
len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost
)
if cost < best_cost:
best_cost = cost
......@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
# split_mode:
# 'node': fx_node
# 'block': many fx_nodes construct a block
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
assert mode in ['node', 'block']
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01):
assert mode in ["node", "block"]
# nodes or blocks will be used in partition.
node_list = []
if mode == 'node':
if mode == "node":
for node in gm.graph.nodes:
node_list.append(node)
elif mode == 'block':
elif mode == "block":
node_list = construct_blocks(gm, limit=block_limit)
else:
pass
......@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
for (_, next_start_node) in best_solution:
for _, next_start_node in best_solution:
if pp_size <= 1:
break
node = node_list[next_start_node]
with gm.graph.inserting_before(node):
split_node = gm.graph.create_node('call_function', pipe_split)
split_node = gm.graph.create_node("call_function", pipe_split)
pp_size -= 1
# remove block node if possible
if mode == 'block':
if mode == "block":
remove_blocks(gm)
gm.recompile()
......@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
# To use avgcompute_split_pass, we need run meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta:
if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_fwd_flop = 0
......@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
if 'pipe_split' in node.name:
if "pipe_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
if accumulate_fwd_flop >= partition_flop:
......@@ -199,14 +200,14 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_flop = total_fwd_flop // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In avgnode_split_pass, simpliy split graph by node number.
In avgnode_split_pass, simply split graph by node number.
"""
mod_graph = gm.graph
avg_num_node = len(mod_graph.nodes) // pp_size
......@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0
pp_size -= 1
if node.next.op == 'output':
if node.next.op == "output":
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
# If the next node is output node, we will insert split annotation before
# node to make sure there is at least one node in last partition.
if node.next.op == 'output':
if node.next.op == "output":
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
if node.op == 'placeholder':
if node.op == "placeholder":
continue
elif node_counter == 0:
node_counter += 1
......@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
# To use balanced_split_pass_v2, we need run meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta:
if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_element_size = 0
......@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
if 'pipe_split' in node.name:
if "pipe_split" in node.name:
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
......@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
accumulate_layer_amount = 0
pp_size -= 1
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
......@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
def split_callback(n: torch.fx.Node):
nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split):
if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
......@@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split):
if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment