Unverified Commit b231430b authored by Boyuan Yao's avatar Boyuan Yao Committed by GitHub
Browse files

[fx] Fix wrong index in annotation and minimal flops in ckpt solver (#1521)

* [fx] fix wrong variable name in solver rotor

* [fx] fix wrong variable name in solver rotor

* [fx] fix the discretize bug

* [fx] fix the first op in activation checkpoint codegen

* [fx] fix some bugs of ckpt solver

* [fx] modify test_ckpt_torchvision

* [fx] set sequence to __sequence__ attr of GraphModule

* [fx] docstring modification

* [fx] remove performance test
parent 07f5a4e0
from typing import List, Set, Tuple, Dict from typing import List, Set, Tuple, Dict
import torch import torch
from torch.fx import GraphModule, Node from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule
import math import math
from .linearize import linearize from .linearize import linearize
from .utils import * from .utils import *
...@@ -131,10 +132,10 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i ...@@ -131,10 +132,10 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel * x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size()) torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
for node in node_dict[key]: for node in node_dict[key]:
fwd_time[-1] += node.__flops__ fwd_time[-1] += max(node.__flops__, 1)
# currently we haven't patched the backward flops count # currently we haven't patched the backward flops count
bwd_time[-1] += node.__flops__ * 2 bwd_time[-1] += max(node.__flops__ * 2, 2)
xbar_sizes[-1] += node.__activation__ xbar_sizes[-1] += node.__activation__
...@@ -164,16 +165,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G ...@@ -164,16 +165,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
elif isinstance(op, ForwardEnable): elif isinstance(op, ForwardEnable):
in_ckpt = False in_ckpt = False
for idx in ckpt_region: for node_idx in ckpt_region:
for node in node_dict[idx]: for node in node_dict[node_idx]:
setattr(node, "activation_checkpoint", ckpt_idx) setattr(node, "activation_checkpoint", ckpt_idx)
ckpt_idx += 1 ckpt_idx += 1
ckpt_region = [] ckpt_region = []
elif isinstance(op, ForwardCheck): elif isinstance(op, ForwardCheck):
for idx in ckpt_region: for node_idx in ckpt_region:
for node in node_dict[idx]: for node in node_dict[node_idx]:
setattr(node, "activation_checkpoint", ckpt_idx) setattr(node, "activation_checkpoint", ckpt_idx)
ckpt_idx += 1 ckpt_idx += 1
...@@ -185,7 +186,19 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G ...@@ -185,7 +186,19 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
ckpt_region.append(idx) ckpt_region.append(idx)
def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule: def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
gm (ColoGraphModule): ColoGraphModule generated by tracing model.
data (torch.Tensor): input data.
mem_limit (int): memory budget in bytes.
mem_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
Returns:
ColoGraphModule: annotated ColoGraphModule with __sequence__ attribute
"""
node_dict = linearize(gm) node_dict = linearize(gm)
mem_unit = mem_limit // mem_slots mem_unit = mem_limit // mem_slots
MetaInfoProp(gm).run(data) MetaInfoProp(gm).run(data)
...@@ -193,4 +206,7 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: ...@@ -193,4 +206,7 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots:
opt_table = _compute_table(chain, mem_slots) opt_table = _compute_table(chain, mem_slots)
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
_annotate_from_sequence(sequence, node_dict) _annotate_from_sequence(sequence, node_dict)
# set the __sequence__ attribute on the GraphModule
setattr(gm, "__sequence__", sequence)
return gm return gm
...@@ -62,13 +62,13 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call ...@@ -62,13 +62,13 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
def _run_ckpt_solver(rank): def _run_ckpt_solver(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
MODEL_LIST = [tm.resnet18, tm.densenet121] MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
tracer = ColoTracer(trace_act_ckpt=False) tracer = ColoTracer(trace_act_ckpt=False)
data = torch.rand(2, 3, 32, 32, device='meta') data = torch.rand(8, 3, 224, 224, device='meta')
for solver in SOLVERS: for solver in SOLVERS:
for model_cls in MODEL_LIST: for model_cls in MODEL_LIST:
m = model_cls(num_classes=5) m = model_cls(num_classes=5)
...@@ -95,13 +95,13 @@ def test_ckpt_solver(): ...@@ -95,13 +95,13 @@ def test_ckpt_solver():
def _run_ckpt_solver_torch11(rank): def _run_ckpt_solver_torch11(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
MODEL_LIST = [tm.resnet18, tm.densenet121] MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
tracer = ColoTracer(trace_act_ckpt=False) tracer = ColoTracer(trace_act_ckpt=False)
data = torch.rand(2, 3, 32, 32, device='meta') data = torch.rand(8, 3, 32, 32, device='meta')
for solver in SOLVERS: for solver in SOLVERS:
for model_cls in MODEL_LIST: for model_cls in MODEL_LIST:
m = model_cls(num_classes=5) m = model_cls(num_classes=5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment