Unverified Commit bdc0581b authored by Mehdi Mirzazadeh, committed by GitHub

adding auto graph generation for distributed pipeline (#615)

* adding auto graph generation for distributed pipeline

* ignore trace.py for mypy for now, since it needs pytorch 1.8

* fixing tests

* simplifying graph api

* remove unused debug utilities

* use inspect to find argument lists

* use sharded linear layer

* flake8

* comment

* polishing

* polishing
parent 2bb2a134
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple
from typing import List, Optional, Set, Tuple, Union
from torch import Tensor, nn
from torch.distributed import rpc
@@ -78,46 +78,49 @@ class PipelineModulesGraph(nn.Module):
self.nodes.append(new_node)
return new_node
def _set_inputs(self, module: RemoteModule, inputs: List[DataSource]) -> None:
self._find_or_add(module).inputs = inputs
def add_sequence(self, modules: List[RemoteModule], first_input: Optional[RemoteModule] = None) -> None:
"""Adds a list of modules to the graph, to be run sequentially.
The connection between these modules is as follows: the first output of each of these modules
(except the last one) is used as the first input of its next module in this sequence.
The user may also specify the input to the first module in this sequence with argument 'first_input'.
In this case the module 'first_input' must have been added to the graph previously.
# DataSourceSpec lists choices the user has for specifying the source of each input to a module:
# -- If the input is one of the model inputs, it is specified by a simple integer, which is the index of that input
# -- If the input comes from a module with a simple output, it is specified by that module
# -- If the input comes from a module with multiple outputs (a tuple), it is specified by that module and the
# index of the output
DataSourceSpec = Union[int, RemoteModule, Tuple[RemoteModule, int]]
def _data_source_spec_to_data_source(self, spec: DataSourceSpec) -> DataSource:
if isinstance(spec, int):
return DataSource(None, spec)
if isinstance(spec, RemoteModule):
return DataSource(self._find_node(spec), 0)
return DataSource(self._find_node(spec[0]), spec[1])
def add_layer(self, module: RemoteModule, inputs: List[DataSourceSpec], num_outputs: Optional[int] = None) -> None:
"""Adds a module with specified inputs to the graph. The modules that provide inputs to this module must have
been added previously to the graph and are listed with argument inputs. If the module output is a tuple,
num_outputs specifies the number of elements in the tuple.
"""
old_modules_len = len(self.nodes)
new_modules_len = len(modules)
self.nodes.extend(Node(mod) for mod in modules)
# update inputs array
if first_input is not None:
self.nodes[old_modules_len].inputs = [DataSource(self._find_node(first_input), 0)]
for i in range(old_modules_len + 1, old_modules_len + new_modules_len):
self.nodes[i].inputs = [DataSource(self.nodes[i - 1], 0)]
def set_model_input(self, module: RemoteModule, ind: int = 0) -> None:
"""Declares the input to a module as the input to the model. In case the model has multiple
inputs, the argument 'ind' indicates the index of the model input that is fed to the module.
"""
self._set_inputs(module, [DataSource(None, ind)])
def add_multi_input_layer(self, module: RemoteModule, inputs: List[RemoteModule]) -> None:
"""Adds a module with multiple inputs to the graph. The modules that provide inputs to this module
must have been added previously to the graph and are listed with argument inputs.
"""
self._set_inputs(module, [DataSource(self._find_node(m), 0) for m in inputs])
def fan_out(self, module: RemoteModule, outputs: List[RemoteModule]) -> None:
"""Feeds outputs of a previously added module to modules specified by argument 'outputs' (so
'module' should have at least 'len(outputs)' outputs).
Modules in the list 'outputs' are added to the graph if they have not been added previously.
node = Node(module)
node.inputs = [self._data_source_spec_to_data_source(spec) for spec in inputs]
node.num_outputs = num_outputs
self.nodes.append(node)
def add_sequence(
self,
modules: List[RemoteModule],
first_module_inputs: List[DataSourceSpec],
last_module_num_outputs: Optional[int] = None,
) -> None:
"""Adds a list of modules to the graph, to be run sequentially.
The connection between these modules is as follows: the output of each of these modules
(except the last one) is used as the input of its next module in this sequence.
So all modules (except the last one) must have a simple output, and all of them (except the first one)
should take a single input.
The user also specifies the input to the first module in this sequence with argument 'first_module_inputs'.
In case the last module output is a tuple, 'last_module_num_outputs' specifies the number of elements
in the tuple.
"""
node = self._find_node(module)
node.num_outputs = len(outputs)
for i, m in enumerate(outputs):
self._set_inputs(m, [DataSource(node, i)])
next_input = first_module_inputs
for i, module in enumerate(modules):
self.add_layer(module, next_input, last_module_num_outputs if i == len(modules) - 1 else None)
next_input = [module]
def _compile(self) -> None:
"""Precomputes self.model_input_consumers and self.output_consumers for internal use by the pipleine
......
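The hunk above replaces the old `set_model_input` / `add_multi_input_layer` / `fan_out` calls with a single `add_layer` method driven by `DataSourceSpec`, plus a generalized `add_sequence`. A minimal sketch of how a fan-out/fan-in graph could be wired with the new API (the module variables below are hypothetical placeholders standing in for previously created `RemoteModule` instances, not names from this diff):

```python
# Hypothetical usage sketch of the new PipelineModulesGraph API.
# 'feature', 'split', 'shard_a', 'shard_b', and 'concat' are assumed to be
# RemoteModule instances created elsewhere; they are placeholders.
graph = PipelineModulesGraph()

# Model input 0 feeds 'feature'; 'feature' feeds 'split', whose output is a
# tuple of two tensors (hence last_module_num_outputs=2).
graph.add_sequence([feature, split], first_module_inputs=[0], last_module_num_outputs=2)

# Each shard consumes one element of split's output tuple: (module, index).
graph.add_layer(shard_a, [(split, 0)])
graph.add_layer(shard_b, [(split, 1)])

# A module with several inputs just lists all of its source modules.
graph.add_layer(concat, [shard_a, shard_b])
```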
import inspect
import operator
from typing import Dict, List, Optional, Tuple, cast
from torch.distributed.nn import RemoteModule
import torch.fx
import torch.nn as nn
from . import PipelineModulesGraph
class RemoteModuleTracer(torch.fx.Tracer):
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
if isinstance(m, RemoteModule):
return True
return False
class GraphCreator:
def __init__(self, tracer: RemoteModuleTracer) -> None:
self.tracer = tracer
def get_module(self, node: torch.fx.Node) -> Optional[nn.Module]:
"""Given a call_module node, returns the module corresponding to this module call"""
if node.op != "call_module":
return None
module = self.tracer.root
for t in cast(str, node.target).split("."):
module = getattr(module, t)
return module
def create_graph(self) -> PipelineModulesGraph:
# node_to_data maps nodes to the data they represent
node_to_data: Dict[torch.fx.Node, PipelineModulesGraph.DataSourceSpec] = {}
remote_module_nodes: List[Tuple[torch.fx.Node, RemoteModule]] = []
for node in self.tracer.graph.nodes:
if node.op == "call_module":
module = self.get_module(node)
assert isinstance(module, RemoteModule)
node_to_data[node] = module
remote_module_nodes.append((node, module))
elif node.target == operator.__getitem__ and node.op == "call_function":
assert node.args[0] in node_to_data
d = node_to_data[node.args[0]]
assert isinstance(d, RemoteModule)
node_to_data[node] = (d, node.args[1])
elif node.op == "placeholder":
arg_names = list(inspect.signature(self.tracer.root.forward).parameters)
node_to_data[node] = arg_names.index(node.target)
elif node.op == "output":
pass
else:
assert False, "Invalid node %s" % node
# The following dict stores the cardinality of the output for each module: None if the output is a simple
# tensor, or the number of tensors in the output if it is a tuple.
module_to_num_outputs: Dict[nn.Module, Optional[int]] = {}
for node, _ in remote_module_nodes:
# iterate over inputs to the module.
for arg in node.args:
data = node_to_data[arg]
if isinstance(data, int):
continue
if isinstance(data, RemoteModule):
assert module_to_num_outputs.get(data, None) is None
module_to_num_outputs[data] = None
else:
module, output_num = data
# Here we discovered that the output number "output_num" is used,
# so the number of outputs should be at least "output_num+1".
if module in module_to_num_outputs:
prev_value = module_to_num_outputs[module]
assert prev_value is not None
module_to_num_outputs[module] = max(prev_value, output_num + 1)
else:
module_to_num_outputs[module] = output_num + 1
graph = PipelineModulesGraph()
for node, module in remote_module_nodes:
inputs = [node_to_data[arg] for arg in node.args]
graph.add_layer(module, inputs, module_to_num_outputs.get(module))
return graph
def _call_trace(tracer: RemoteModuleTracer, module: nn.Module) -> torch.fx.Graph:
try:
org_named_modules = RemoteModule.named_modules
org_named_children = RemoteModule.named_children
RemoteModule.named_modules = nn.Module.named_modules # type: ignore
RemoteModule.named_children = nn.Module.named_children # type: ignore
return tracer.trace(module)
finally:
RemoteModule.named_modules = org_named_modules # type: ignore
RemoteModule.named_children = org_named_children # type: ignore
def make_graph(module: nn.Module) -> PipelineModulesGraph:
"""
Creates a PipelineModulesGraph for the module. The module must be traceable by torch.fx,
and all operations on tensors should be performed by RemoteModules.
"""
tracer = RemoteModuleTracer()
r = _call_trace(tracer, module)
g = torch.fx.GraphModule(module, r)
return GraphCreator(tracer).create_graph()
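For context on the tracing step above: `RemoteModuleTracer.is_leaf_module` keeps every `RemoteModule` opaque, so each one appears as a single `call_module` node that `create_graph` can map to a graph layer. The following standalone illustration of the same `torch.fx` mechanism uses plain `nn` modules so it runs without an RPC setup; it is not part of the diff:

```python
# Standalone torch.fx illustration (assumes only a local PyTorch >= 1.8 install).
import torch.fx
import torch.nn as nn


class LinearLeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # Only nn.Linear is kept opaque; nn.ReLU is traced into its
        # functional call instead of appearing as a call_module node.
        return isinstance(m, nn.Linear)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
graph = LinearLeafTracer().trace(model)
for node in graph.nodes:
    # Prints node kinds in order: placeholder (model input),
    # call_module ("0", the Linear), call_function (relu), output.
    print(node.op, node.target)
```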
@@ -52,6 +52,9 @@ disallow_untyped_decorators = true
disallow_incomplete_defs = true
warn_unused_ignores = true
[mypy-fairscale.experimental.nn.distributed_pipeline.trace]
ignore_errors = True
[mypy-benchmarks.*]
ignore_errors = True
......
@@ -76,8 +76,7 @@ def create_sequence_pipeline(
index = next_index
graph = PipelineModulesGraph()
graph.add_sequence(remote_modules)
graph.set_model_input(remote_modules[0])
graph.add_sequence(remote_modules, [0])
return DistributedPipeline(graph, **kwargs)
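Read as a before/after pair, the change in `create_sequence_pipeline` folds the separate `set_model_input` call into the new `first_module_inputs` argument; a small sketch of the two forms, reusing the variable names from the diff:

```python
# Before this change: the model input was declared in a separate call.
graph.add_sequence(remote_modules)
graph.set_model_input(remote_modules[0])

# After this change: model input 0 is passed as the first module's input source.
graph.add_sequence(remote_modules, [0])
```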
@@ -189,7 +188,6 @@ def update(devices):
x = torch.randn(8, 4).to(device)
model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4, devices=devices[:2])
params = pipe.parameter_rrefs()
opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,)
losses = []
for i in range(2):
@@ -238,14 +236,63 @@ def multi_input_multi_output_layers(devices):
concatenate = RemoteModule(devices[1], ConcatenateTensors, ())
graph = PipelineModulesGraph()
graph.add_sequence([linear_layer_1, split])
graph.set_model_input(linear_layer_1)
graph.fan_out(split, linear_layers_2)
graph.add_multi_input_layer(concatenate, linear_layers_2)
graph.add_sequence([linear_layer_1, split], [0], 2)
for i, l in enumerate(linear_layers_2):
graph.add_layer(l, [(split, i)])
graph.add_layer(concatenate, linear_layers_2)
pipe = DistributedPipeline(graph, chunks=4)
assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)
params = pipe.parameter_rrefs()
opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,)
losses = []
for i in range(2):
with dist_autograd.context() as context_id:
y = pipe(x)
loss = criterion(y, rpc.RRef(x))
losses.append(loss)
loss.backward(context_id)
opt.step(context_id)
losses = [l.to_here() for l in losses]
assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"
# A test for automatically extracting the same graph as in the multi_input_multi_output_layers test
class ShardedLinearLayer(nn.Module):
def __init__(self, input_device, shard_devices, output_device):
super().__init__()
self.split = RemoteModule(input_device, SplitTensors, (), {})
self.linear_layers_2 = nn.ModuleList(
[
RemoteModule(shard_devices[0], nn.Linear, (2, 2), {}),
RemoteModule(shard_devices[1], nn.Linear, (2, 2), {}),
]
)
self.concatenate = RemoteModule(output_device, ConcatenateTensors, ())
def forward(self, input):
shards = self.split(input)
shards = [self.linear_layers_2[i](shards[i]) for i in range(2)]
return self.concatenate(*shards)
@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def auto_graph_extract(devices):
from fairscale.experimental.nn.distributed_pipeline.trace import make_graph
device = devices[0].split("/")[1]
torch.random.manual_seed(3)
criterion = DistributedLoss(torch.nn.MSELoss)
x = torch.randn(8, 4).to(device)
# create model
model = nn.Sequential(
RemoteModule(devices[0], nn.Linear, (4, 4), {}), ShardedLinearLayer(devices[0], devices, devices[1])
)
graph = make_graph(model)
pipe = DistributedPipeline(graph, chunks=4)
partitions = extract_partitions(graph, pipe)
assert [[0, 1], [2], [3], [4]] == partitions, f"partitions={partitions}"
opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,)
losses = []
for i in range(2):
......