Unverified commit 3345c6d3 authored by YuliangLiu0306, committed by GitHub

[autoparallel] add strategies constructor (#1505)

* [autoparallel] add strategies constructor

* remove duplicated strategies

* polish code

* adapt cost graph with StrategiesConstructor

* polish
parent a0436a62
import torch
import operator

__all__ = [
    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
    'LINEAR_FUNC_OP'
]

# element-wise ops: the output sharding spec mirrors the input spec, so these
# nodes are candidates for merging into neighbouring nodes in the cost graph
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
]

# convolution modules and functions
CONV_MODULE_OP = [
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d
]
CONV_FUNC_OP = [
    torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]

# linear modules and functions
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
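
These tables drive operator dispatch and merge decisions in the solver. A minimal sketch of that kind of membership lookup follows; classify_node is a hypothetical helper, not part of this commit, and assumes the tables above are in scope:

def classify_node(node, root_module):
    # illustrative only: map an fx node to an operator family via the tables above
    if node.op == 'call_module':
        submod_type = type(root_module.get_submodule(node.target))
        if submod_type in CONV_MODULE_OP:
            return 'conv'
        if submod_type in LINEAR_MODULE_OP:
            return 'linear'
        if submod_type in ELEMENTWISE_MODULE_OP:
            return 'elementwise'
    elif node.op == 'call_function':
        if node.target in CONV_FUNC_OP:
            return 'conv'
        if node.target in LINEAR_FUNC_OP:
            return 'linear'
        if node.target in ELEMENTWISE_FUNC_OP:
            return 'elementwise'
    return 'unknown'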
@@ -494,3 +494,10 @@ class ConvHandler(OperatorHandler):
         self.split_1d_parallel_on_in_channel(0, 1)
         return self.strategies_vector
+
+CONV_STRATEGIES_LIST = [
+    'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
+    'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
+    'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
+]
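
Each strategy name encodes output = input x weight sharding patterns: R means replicated, S0/S1 sharded along mesh dimension 0/1, and S01 sharded along both mesh dimensions flattened together. A tiny sketch of how such a name decomposes (parse_strategy_name is a hypothetical helper, not part of this commit):

def parse_strategy_name(name):
    # 'S0S1 = S0R x RS1' -> ('S0S1', 'S0R', 'RS1')
    output, rhs = [part.strip() for part in name.split('=')]
    lhs_input, weight = [part.strip() for part in rhs.split('x')]
    return output, lhs_input, weight

assert parse_strategy_name('S0S1 = S0R x RS1') == ('S0S1', 'S0R', 'RS1')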
@@ -39,11 +39,11 @@ class CostGraph:
             dst_node = strategies_vector.node
             for src_node in strategies_vector.predecessor_nodes:
                 node_pair = (src_node, dst_node)
-                src_index = strategies_vector.predecessor_nodes.index(src_node)
+                # src_index = strategies_vector.predecessor_nodes.index(src_node)
                 edge_cost = {}
                 for i in range(len(strategies_vector)):
-                    for j in range(len(src_node.stategy_vector)):
-                        edge_cost[(i, j)] = strategies_vector[i].resharding_costs[src_index][j]
+                    for j in range(len(src_node.strategies_vector)):
+                        edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
                 self.edge_costs[node_pair] = edge_cost
             # add parents and children attribute to node
             setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
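
After this change, resharding_costs is keyed by the source Node itself rather than by its positional index, and edge-cost keys are ordered (src_strategy_index, dst_strategy_index). A toy illustration of the resulting layout (values are placeholders):

# a src node with 2 strategies feeding a dst node with 3 strategies
num_src, num_dst = 2, 3
edge_cost = {(j, i): 0.0 for i in range(num_dst) for j in range(num_src)}
assert len(edge_cost) == num_src * num_dst
assert (1, 2) in edge_cost    # (src strategy 1, dst strategy 2)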
@@ -83,26 +83,40 @@ class CostGraph:
         merge_map = {}
         for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
             resharding_costs = strategy.resharding_costs
-            resharding_cost_for_src = resharding_costs[src_node_index]
+            resharding_cost_for_src = resharding_costs[src_node]
             lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
             merge_map[dst_strate_index] = lowest_cost_index
         # extra_node_cost for dst node
-        extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
+        self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
         for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
             target_strate_index = merge_map[dst_strate_index]
-            extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node_index][
-                target_strate_index]
-            if src_node in extra_node_costs:
-                extra_node_costs[dst_node][dst_strate_index] += extra_node_costs[src_node][target_strate_index]
+            self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node][
+                target_strate_index]
+            if src_node in self.extra_node_costs:
+                self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][
+                    target_strate_index]
+        # add new node pair to cost graph
+        for parent_node in src_node.parents:
+            new_node_pair = (parent_node, dst_node)
+            old_node_pair = (parent_node, src_node)
+            if new_node_pair in self.edge_costs:
+                continue
+            edge_cost = {}
+            for i in range(self.node_lens[dst_node]):
+                for j in range(self.node_lens[parent_node]):
+                    src_strate_index = merge_map[i]
+                    edge_cost[(j, i)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
+            self.edge_costs[new_node_pair] = edge_cost
         # connect dst node and parents of src node
         dst_node.parents.remove(src_node)
         src_node.children.remove(dst_node)
-        node_pair_to_remove = [(src_node, dst_node)]
+        self.edge_costs.pop((src_node, dst_node))
         for parent_node in src_node.parents:
             if parent_node not in dst_node.parents:
-                dst_node.parents.append(parent)
+                dst_node.parents.append(parent_node)
             if dst_node not in parent_node.children:
                 parent_node.children.append(dst_node)
         # remove src node from cost graph when src node has no consumer.
@@ -111,19 +125,6 @@ class CostGraph:
                 node_pair = (parent_node, src_node)
                 self.edge_costs.pop(node_pair)
-        # add new node pair to cost graph
-        for parent_node in src_node.parents:
-            new_node_pair = (parent_node, dst_node)
-            old_node_pair = (parent_node, src_node)
-            if new_node_pair in self.edge_costs:
-                continue
-            edge_cost = {}
-            for i in range(self.node_lens[dst_node]):
-                for j in range(self.node_lens[parent_node]):
-                    src_strate_index = merge_map[i]
-                    edge_cost[(i, j)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
-            self.edge_costs[new_node_pair] = edge_cost

     def simplify_graph(self):
         if not self.simplify:
             return
...
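The merge step above eliminates a mergeable node by picking, for each destination strategy, the source strategy that reshards into it most cheaply, folding that cost into the destination's extra node cost, and rewiring the source's parents to the destination. A self-contained toy version of the selection rule, with plain dicts standing in for nodes and strategies:

# dst strategy index -> resharding cost per src strategy (placeholder values)
resharding = {0: [3.0, 0.0], 1: [0.0, 5.0]}
merge_map = {i: costs.index(min(costs)) for i, costs in resharding.items()}
extra_node_cost = [resharding[i][merge_map[i]] for i in sorted(resharding)]
assert merge_map == {0: 1, 1: 0}        # cheapest src strategy per dst strategy
assert extra_node_cost == [0.0, 0.0]    # folded into the dst node's cost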
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from abc import ABC, abstractmethod
 from torch.fx.node import Node
-from typing import Dict
+from typing import Dict, List
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -56,7 +56,7 @@ class OperatorHandler(ABC):
         """
         pass

-    def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
+    def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
         """
         Generate the sharding spec of the tensor based on the given dim_partition_dict
         where the key is the tensor dimension and the value is the mesh dimension for sharding.
@@ -84,7 +84,9 @@ class OperatorHandler(ABC):
         for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
             resharding_costs[input_node] = []
             for strategy in input_node.strategies_vector:
+                input_sharding_spec = strategy.output_sharding_spec
+                assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
                 _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
-                    strategy.output_sharding_spec, input_spec)
+                    input_sharding_spec, input_spec)
                 resharding_costs[input_node].append(resharding_cost)
         return resharding_costs
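
The loop above now binds each predecessor strategy's output sharding spec explicitly before querying the shape-consistency manager, whose 3-tuple return value carries the resharding cost last (the first two values are ignored here, as in the diff). A condensed sketch of the same pattern; the function name and structure are illustrative, not library API:

def resharding_cost_vector(shape_consistency_manager, input_node, target_spec):
    # one cost per strategy of the predecessor node
    costs = []
    for strategy in input_node.strategies_vector:
        input_sharding_spec = strategy.output_sharding_spec
        assert isinstance(input_sharding_spec, ShardingSpec)
        _, _, cost = shape_consistency_manager.shape_consistency(input_sharding_spec, target_spec)
        costs.append(cost)
    return costs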
 from dataclasses import dataclass
 from colossalai.tensor.sharding_spec import ShardingSpec
-from typing import Dict, List
+from typing import Dict, List, Union, Tuple
 from torch.fx.node import Node
+from .constants import *

 __all__ = ['ShardingStrategy', 'StrategiesVector']
@@ -25,12 +26,15 @@ class ShardingStrategy:
     '''
     name: str
-    output_sharding_spec: ShardingSpec
+    # TODO: the output of an fx node, such as torch.var_mean, could be a tuple, so we cannot simply assume it is a single tensor.
+    output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
     compute_cost: float = 0.
     communication_cost: float = 0.
     memory_cost: float = 0.
-    resharding_costs: Dict[int, List[float]] = None
+    resharding_costs: Dict[Node, List[float]] = None
-    input_shardings: ShardingSpec = None
+    # sometimes the input node could be a tuple of nodes, but most ops won't accept a tuple of nodes as input.
+    # Therefore, we can handle them at the specific op (operator.getitem).
+    input_shardings: List[ShardingSpec] = None
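
Because dataclasses do not enforce annotations at runtime, the new Union-typed field accepts either a single spec or a tuple of specs. The snippet below uses placeholder strings purely to stay self-contained; real callers pass ShardingSpec objects built from a DeviceMesh, as in the tests further down:

single = ShardingStrategy(name='conv-like', output_sharding_spec='RR')
multi = ShardingStrategy(name='var_mean-like', output_sharding_spec=('RR', 'RR'))
assert not isinstance(single.output_sharding_spec, tuple)
assert isinstance(multi.output_sharding_spec, tuple)
assert single.compute_cost == 0. and single.resharding_costs is None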
 class StrategiesVector(list):
@@ -46,8 +50,23 @@ class StrategiesVector(list):
         super().__init__()
         self.node = node
         # fetch its input and output nodes
+        # TODO: placeholder input nodes
         self.predecessor_nodes = list(node._input_nodes.keys())
         self.successor_nodes = list(node.users.keys())

     def check_merge(self):
-        pass
+        merge_label = False
+        if self.node.op == 'call_module':
+            target = self.node.target
+            root_module = self.node.graph.owning_module
+            submod = root_module.get_submodule(target)
+            submod_type = type(submod)
+            # merge elementwise module node into following nodes
+            if submod_type in ELEMENTWISE_MODULE_OP:
+                merge_label = True
+        if self.node.op == 'call_function':
+            if self.node.target in ELEMENTWISE_FUNC_OP:
+                merge_label = True
+        return merge_label
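
check_merge is what lets the cost graph fold element-wise nodes into their consumers. A minimal standalone mirror of its call_function branch; FakeNode is a stand-in for torch.fx.Node, and the op tables from constants above are assumed in scope:

import operator
import torch

class FakeNode:

    def __init__(self, op, target):
        self.op = op
        self.target = target

def is_mergeable_function(node):
    # mirrors the call_function branch of check_merge
    return node.op == 'call_function' and node.target in ELEMENTWISE_FUNC_OP

assert is_mergeable_function(FakeNode('call_function', operator.mul))
assert not is_mergeable_function(FakeNode('call_function', torch.conv2d))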
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x * 2
        x = self.conv1(x)
        x = x / 2
        x = self.relu(x)
        return x
def test_cost_graph():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 16, 64, 64))
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = ConvModel(16, 32)
    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    #     %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
    #     %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {})
    #     %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {})
    #     return relu
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    solver_options = {'fast_mode': True}
    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
    strategies_constructor.build_strategies_and_cost()
# (x, mul): {(0, 0): 0}
# (mul, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
# (conv1, truediv): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): 0, (8, 0): inf, (9, 0): 0, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): inf, (5, 1): 0, (6, 1): 0, (7, 1): inf, (8, 1): 0, (9, 1): inf, (10, 1): 0, (11, 1): 0, (12, 1): 0, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): 0, (11, 2): 0, (12, 2): 0, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): 0, (11, 3): 0, (12, 3): 0, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): 0, (11, 4): 0, (12, 4): 0, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): 0, (11, 5): 0, (12, 5): 0, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): 0, (11, 7): 0, (12, 7): 0, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0, (9, 8): inf, (10, 8): 0, (11, 8): 0, (12, 8): 0, (13, 8): inf, (14, 8): 0}
# (truediv, relu): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): inf, (5, 0): 0, (6, 0): 0, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): 0, (5, 1): inf, (6, 1): 0, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): 0, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): 0, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): 0, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): 0, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): 0, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0}
# (relu, output): {(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002}
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)

    # construct all node pairs
    all_node_pairs = []
    for node in graph.nodes:
        if node.op == 'output':
            continue
        all_node_pairs.append((node, node.next))

    for node_pair in all_node_pairs:
        assert node_pair in cost_graph.edge_costs

    # construct merged node pairs
    merged_node_pairs = []
    node_list = list(graph.nodes)
    # add (x, conv1) and (conv1, output) into the check node pairs
    merged_node_pairs.append((node_list[0], node_list[2]))
    merged_node_pairs.append((node_list[2], node_list[-1]))
# (x, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
# (conv1, output): {(0, 0): inf, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf}
    cost_graph.simplify_graph()
    for node_pair in all_node_pairs:
        if node_pair in merged_node_pairs:
            assert node_pair in cost_graph.edge_costs
        else:
            assert node_pair not in cost_graph.edge_costs


if __name__ == '__main__':
    test_cost_graph()
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from copy import deepcopy
class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)

    def forward(self, x):
        x = x * 2
        x = self.conv(x)
        return x
def test_strategies_constructor():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 16, 64, 64))
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = ConvModel(16, 32)
    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
    #     return conv
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    solver_options = {'fast_mode': True}
    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
    assert strategies_constructor.leaf_strategies == []
    assert strategies_constructor.strategy_map == {}
    strategies_constructor.build_strategies_and_cost()
    # check leaf_strategies

    # In fast mode, a placeholder node only has the replica strategy.
    assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'

    # The second node is mul, which is an element-wise node, so its output sharding spec is the same as its input sharding spec.
    assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'

    # The third node is conv.
    conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
    for strategy in strategies_constructor.leaf_strategies[2]:
        conv_check_list.remove(strategy.name)
    assert len(conv_check_list) == 0

    # In fast mode, the output node only has the replica strategy.
    assert strategies_constructor.leaf_strategies[3][0].name == 'Replica Output'

    # check strategy_map
    nodes = [node for node in graph.nodes]

    # In fast mode, a placeholder node only has the replica strategy.
    x = nodes[0]
    assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder'

    # The second node is mul, which is an element-wise node, so its output sharding spec is the same as its input sharding spec.
    mul = nodes[1]
    assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'

    # The third node is conv.
    conv = nodes[2]
    conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
    for strategy in strategies_constructor.strategy_map[conv]:
        conv_check_list.remove(strategy.name)
    assert len(conv_check_list) == 0

    # In fast mode, the output node only has the replica strategy.
    output = nodes[3]
    assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'


if __name__ == '__main__':
    test_strategies_constructor()
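
Taken together, the two tests exercise the full pipeline introduced by this commit. A condensed sketch of that flow, using only StrategiesConstructor and CostGraph calls that appear in the tests above (the wrapper function itself is illustrative):

def build_simplified_cost_graph(graph, device_mesh, shape_consistency_manager):
    # build sharding strategies and resharding costs for every leaf node
    solver_options = {'fast_mode': True}
    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
    strategies_constructor.build_strategies_and_cost()
    # lower them into a cost graph and merge element-wise nodes away
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    return cost_graph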