Unverified Commit 197d0bf4 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] apply repeat block to reduce solving time (#2912)

parent a8480911
...@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc ...@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
This method is used to solve the best solution for the given graph. This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node. The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
''' '''
graph_analyser = GraphAnalyser(gm) # temporarily we use all nodes as liveness list, we count the backward memory cost together with
liveness_list = graph_analyser.liveness_analysis() # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm)
# liveness_list = graph_analyser.liveness_analysis()
cost_graph = CostGraph(strategy_constructor.leaf_strategies) cost_graph = CostGraph(strategy_constructor.leaf_strategies)
cost_graph.simplify_graph() cost_graph.simplify_graph()
solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget) solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
ret = solver.call_solver_serialized_args() ret = solver.call_solver_serialized_args()
solution = list(ret[0]) solution = list(ret[0])
......
...@@ -32,7 +32,7 @@ class Solver: ...@@ -32,7 +32,7 @@ class Solver:
graph: Graph, graph: Graph,
strategies_constructor: StrategiesConstructor, strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph, cost_graph: CostGraph,
graph_analyser: GraphAnalyser, graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0, memory_budget: float = -1.0,
solution_numbers: int = 1, solution_numbers: int = 1,
forward_only: bool = False, forward_only: bool = False,
...@@ -63,7 +63,10 @@ class Solver: ...@@ -63,7 +63,10 @@ class Solver:
self.memory_increasing_coefficient = memory_increasing_coefficient self.memory_increasing_coefficient = memory_increasing_coefficient
else: else:
self.memory_increasing_coefficient = 1 self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis() # temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# self.liveness_list = self.graph_analyser.liveness_analysis()
self.liveness_list = self.nodes
self.node_index_dict = self._generate_node_index_dict() self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding. # The last solution vector of auto sharding.
self.last_s_val = None self.last_s_val = None
...@@ -140,7 +143,7 @@ class Solver: ...@@ -140,7 +143,7 @@ class Solver:
liveness_set = self.liveness_list liveness_set = self.liveness_list
# omit alias_set now # omit alias_set now
alias_set = None alias_set = self.strategies_constructor.alias_set
alias_convert_costs = None alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs # prepare compute_costs, communication_costs and memory_costs
...@@ -230,6 +233,7 @@ class Solver: ...@@ -230,6 +233,7 @@ class Solver:
# 0. Unpack flatten numpy arrays # 0. Unpack flatten numpy arrays
s_follow = following_nodes s_follow = following_nodes
s_alias = alias_set
E = edge_pairs.reshape((-1, 2)) # noqa E = edge_pairs.reshape((-1, 2)) # noqa
r = [] r = []
...@@ -294,8 +298,11 @@ class Solver: ...@@ -294,8 +298,11 @@ class Solver:
if strategies_len[i] == 1: if strategies_len[i] == 1:
s.append([1]) s.append([1])
else: else:
num_nodes += 1 if i not in s_alias:
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
s.append(s[s_alias[i]])
else: else:
if s_follow[i] < len(s): if s_follow[i] < len(s):
s.append(s[s_follow[i]]) s.append(s[s_follow[i]])
...@@ -311,15 +318,20 @@ class Solver: ...@@ -311,15 +318,20 @@ class Solver:
############################# #############################
e = [] e = []
num_edges = 0 num_edges = 0
map_edge_to_idx = {}
for (idx, (i, j)) in enumerate(E): for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1: if len(s[i]) == 1:
e.append(s[j]) e.append(s[j])
elif len(s[j]) == 1: elif len(s[j]) == 1:
e.append(s[i]) e.append(s[i])
else: else:
num_edges += 1 if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx]) assert len(e[idx]) == len(r[idx])
map_edge_to_idx[(i, j)] = idx
for element in s: for element in s:
assert len(element) > 0 assert len(element) > 0
# 2. Set initial value # 2. Set initial value
...@@ -371,13 +383,12 @@ class Solver: ...@@ -371,13 +383,12 @@ class Solver:
# compute memory consumption with liveness set # # compute memory consumption with liveness set #
################################################# #################################################
if memory_budget > 0: if memory_budget > 0:
for liveness_stage in liveness_set: mem = 0
mem = 0 for node in liveness_set:
for live_variable in liveness_stage.unique_live_vars: if node not in self.node_index_dict:
if live_variable.node not in self.node_index_dict: continue
continue node_index = self.node_index_dict[node]
node_index = self.node_index_dict[live_variable.node] mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget prob += mem <= memory_budget
# (d). specified by `cat="Binary"` # (d). specified by `cat="Binary"`
......
...@@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import ( ...@@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
) )
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from ..options import DataloaderOption, SolverOptions from ..options import DataloaderOption, SolverOptions
...@@ -42,6 +43,7 @@ class StrategiesConstructor: ...@@ -42,6 +43,7 @@ class StrategiesConstructor:
self.strategy_map = {} self.strategy_map = {}
self.solver_options = solver_options self.solver_options = solver_options
self.no_strategy_nodes = [] self.no_strategy_nodes = []
self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector): def remove_duplicated_strategy(self, strategies_vector):
''' '''
...@@ -59,6 +61,22 @@ class StrategiesConstructor: ...@@ -59,6 +61,22 @@ class StrategiesConstructor:
for strategy in remove_list: for strategy in remove_list:
strategies_vector.remove(strategy) strategies_vector.remove(strategy)
def generate_alias_set(self):
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
repeat_block_nums = len(common_blocks)
alias_set = {}
if repeat_block_nums == 0:
return alias_set
for index, common_node in enumerate(common_blocks[0]):
for i in range(1, repeat_block_nums):
alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)
return alias_set
def build_strategies_and_cost(self): def build_strategies_and_cost(self):
""" """
This method is to build the strategy vector for each node in the computation graph. This method is to build the strategy vector for each node in the computation graph.
...@@ -175,3 +193,6 @@ class StrategiesConstructor: ...@@ -175,3 +193,6 @@ class StrategiesConstructor:
self.leaf_strategies.remove(node.strategies_vector) self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map: if node in self.strategy_map:
self.strategy_map.pop(node) self.strategy_map.pop(node)
alias_set = self.generate_alias_set()
self.alias_set = alias_set
...@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2 ...@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2
BATCH_SIZE = 1 BATCH_SIZE = 1
SEQ_LENGTH = 32 SEQ_LENGTH = 32
HIDDEN_DIM = 768 HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls): def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM) config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP: if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config) model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
else: else:
...@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls): ...@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
gm = GraphModule(model, graph, model.__class__.__name__) gm = GraphModule(model, graph, model.__class__.__name__)
print(gm.graph) print(gm.graph)
gm.recompile() gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions() solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost() strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph() cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1) solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
ret = solver.call_solver_serialized_args() ret = solver.call_solver_serialized_args()
strategies_list = solver.last_s_val strategies_list = solver.last_s_val
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
......
...@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre ...@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre
from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.tracer import ColoTracer
...@@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, ...@@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
# solution construction # solution construction
cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph() cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm) solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
ret = solver.call_solver_serialized_args() ret = solver.call_solver_serialized_args()
solution = list(ret[0]) solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
......
...@@ -51,15 +51,14 @@ def test_cost_graph(): ...@@ -51,15 +51,14 @@ def test_cost_graph():
# return fc # return fc
gm = GraphModule(model, graph, model.__class__.__name__) gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile() gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions() solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost() strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph() cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) solver = Solver(gm.graph, strategies_constructor, cost_graph)
ret = solver.call_solver_serialized_args() ret = solver.call_solver_serialized_args()
print(ret[0]) print(ret[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment