Unverified Commit 7bdf50a3 authored by Eugen Hotaj, committed by GitHub

Extend auto shard capabilities to work around torch.fx edge cases. (#817)

auto_shard.py currently uses torch.fx to create a symbolic DAG of
operations and linearizes that DAG into an nn.Sequential so it can later
be used for model offloading. This works in most cases but runs into
issues with certain eager-mode features, such as dynamic conditionals and
shape-dependent computation.

This PR extends auto_shard.py to first run a preprocessing step that wraps
any nn.Module which cannot be traced through, treating it as a leaf. It adds
a test for dynamic conditionals and updates existing failing test code.

Some immediate extensions to this approach are marked as TODOs in the code.
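
For illustration only (not part of this commit): a minimal sketch of the failure
mode and of the leaf-wrapping idea. DynamicBranch and LeafTracer are hypothetical
names invented for this example.

import torch
import torch.fx


class DynamicBranch(torch.nn.Module):
    # Hypothetical module: the Python conditional below depends on tensor
    # values, which torch.fx cannot evaluate on its symbolic proxies.
    def __init__(self):
        super().__init__()
        self.left = torch.nn.Linear(4, 4)
        self.right = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.left(x) if x.sum() > 0 else self.right(x)


model = torch.nn.Sequential(DynamicBranch(), torch.nn.ReLU())

try:
    torch.fx.symbolic_trace(model)
except torch.fx.proxy.TraceError as e:
    # "symbolically traced variables cannot be used as inputs to control flow"
    print(f"Tracing failed: {e}")

# Marking the offending submodule as a leaf lets tracing succeed: the branch
# is kept as an opaque call_module node and its conditional runs eagerly at
# execution time with real tensors.
class LeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m: torch.nn.Module, qualname: str) -> bool:
        return isinstance(m, DynamicBranch) or super().is_leaf_module(m, qualname)


graph = LeafTracer().trace(model)
graph_module = torch.fx.GraphModule(model, graph)
print(graph_module(torch.randn(2, 4)).shape)  # torch.Size([2, 4])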
parent e4da75ea
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
import logging
-from typing import Dict, List
+from typing import Dict, List, Set
import torch
import torch.fx
@@ -40,7 +40,7 @@ def _create_shard_to_param_count(param_count: Dict, node_name_to_shard_id: Dict)
    return shard_to_param_count


-def _split_nodes(model: torch.nn.Module, shard_count: int = 3) -> Dict:
+def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
    """Utility used to trace a graph and identify shard cutpoints."""

    node_name_to_shard_id: Dict[str, int] = {}
@@ -49,15 +49,12 @@ def _split_nodes(model: torch.nn.Module, shard_count: int = 3) -> Dict:
    param_count: Dict[str, int] = {}
    shard_to_param_count = {}

-    traced_graph_module = torch.fx.symbolic_trace(model)
-
    # Find the total number of params in the model and
    # the number of params per shard we are aiming for.
-    for name, module in model.named_modules():
+    for name, module in traced_graph_module.named_modules():
        if "." in name:
            continue
        param_count[name] = sum([x.numel() for x in module.parameters()])
-
    logging.info(f"Total number of params are {param_count['']}")
    per_shard_param = param_count[""] // shard_count
    logging.info(f"Per shard param count {per_shard_param}")
@@ -112,6 +109,46 @@ def _split_nodes(model: torch.nn.Module, shard_count: int = 3) -> Dict:
    return node_name_to_shard_id

+class _ExtendedLeafTracer(torch.fx.Tracer):
+    """Tracer with an extended set of leaf nn.Modules."""
+
+    def __init__(self, leaf_modules: Set[torch.nn.Module]):
+        """Initializes a new _ExtendedLeafTracer object.
+
+        Args:
+            leaf_modules: The set of extra nn.Module instances which will not be
+                traced through but instead considered to be leaves.
+        """
+        super().__init__()
+        self.leaf_modules = leaf_modules
+
+    def is_leaf_module(self, m: torch.nn.Module, model_qualified_name: str) -> bool:
+        return super().is_leaf_module(m, model_qualified_name) or m in self.leaf_modules
+
+
+# TODO(ehotaj): Extend this method to wrap at the least granular level. One way to
+# do this would be to wrap the Module tree bottom up, first wrapping untraceable
+# children and only wrapping parents if they are also untraceable.
+def _trace(model: torch.nn.Module) -> torch.fx.GraphModule:
+    """Traces the given model and automatically wraps untraceable modules into leaves."""
+    leaf_modules = set()
+    tracer = _ExtendedLeafTracer(leaf_modules)
+    for name, module in model.named_modules():
+        # TODO(ehotaj): The default is_leaf_module includes everything in torch.nn.
+        # This means that some coarse modules like nn.TransformerEncoder are treated
+        # as leaves, not traced, and are unable to be sharded. We may want to extend
+        # our sharding code to trace through these modules as well.
+        if tracer.is_leaf_module(module, ""):
+            continue
+        try:
+            tracer.trace(module)
+        except (TypeError, torch.fx.proxy.TraceError):
+            # Module cannot be traced through; record it as an opaque leaf.
+            leaf_modules.add(module)
+            tracer = _ExtendedLeafTracer(leaf_modules)
+    graph = tracer.trace(model)
+    return torch.fx.GraphModule(model, graph)
+
+
def shard_model(model: torch.nn.Module, shard_count: int = 3) -> List[torch.fx.GraphModule]:
    """Utility used to shard a model using torch.fx.
@@ -140,11 +177,12 @@ def shard_model(model: torch.nn.Module, shard_count: int = 3) -> List[torch.fx.GraphModule]:
    new_graph = torch.fx.Graph()  # type: ignore
    env: Dict[str, Node] = {}
    new_input_node = None

+    traced_graph_module = _trace(model)
+
    # This is the first pass where we attempt to get a map of where
    # we need to insert placeholder and output nodes.
-    node_name_to_shard_id = _split_nodes(model, shard_count=shard_count)
-
-    traced_graph_module = torch.fx.symbolic_trace(model)
+    node_name_to_shard_id = _split_nodes(traced_graph_module, shard_count=shard_count)

    # dummy value which indicates that this is the first node.
    prev_shard_id = 1000
@@ -31,9 +31,7 @@ class PositionalEncoding(nn.Module):
        self.register_buffer("pe", pe)

    def forward(self, x):
-        # TODO(anj): Fix the following error when using autoshard
-        # Error: TypeError: slice indices must be integers or None or have an __index__ method
-        # x = x + self.pe[:x.size(0), self.d_model]
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)
@@ -107,3 +105,48 @@ def test_single_run():
    input = model(input)
    assert input.size() == torch.Size([35, 20, 28783])
+
+
+class Branch(torch.nn.Module):
+    def __init__(self, features: int):
+        super().__init__()
+        self.left = nn.Linear(in_features=features, out_features=features)
+        self.right = nn.Linear(in_features=features, out_features=features)
+
+    def forward(self, x):
+        # Dynamic conditional: the branch taken depends on the tensor's values,
+        # so torch.fx cannot trace through this module.
+        if x.sum() > 1000:
+            return self.left(x)
+        else:
+            return self.right(x)
+
+
+class BranchedNetwork(torch.nn.Module):
+    def __init__(self, features: int):
+        super().__init__()
+        self.net = torch.nn.ModuleList([Branch(features) for _ in range(10)])
+
+    def forward(self, x):
+        for module in self.net:
+            x = module(x)
+        return x
+
+
+def test_dynamic_conditionals_auto_wrapped():
+    if torch_version() < (1, 8, 0):
+        pytest.skip("requires torch version >= 1.8.0")
+
+    from fairscale.experimental.nn.auto_shard import shard_model
+
+    features = 10
+    model = BranchedNetwork(features)
+    sharded_model = shard_model(model, 3)
+    # TODO(ehotaj): There might be a bug in our split code because we shard the
+    # model into 10 shards even though we specify 3 shards above.
+    assert len(sharded_model) == 10
+
+    input_ = torch.randn(3, features)
+    model_output = model(input_)
+
+    # Feed the input through each shard in sequence; the composed output should
+    # match the unsharded model exactly.
+    sharded_model_output = input_
+    for shard in sharded_model:
+        sharded_model_output = shard(sharded_model_output)
+
+    assert torch.allclose(model_output, sharded_model_output)
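
A quick way to see what the new preprocessing produces (illustrative only, not
part of this commit): it reuses the Branch/BranchedNetwork test classes above
and imports the private _trace helper, which is an implementation detail.

from fairscale.experimental.nn.auto_shard import _trace

model = BranchedNetwork(features=10)
traced = _trace(model)
for node in traced.graph.nodes:
    # Each untraceable Branch surfaces as a single opaque call_module node
    # ("net.0", "net.1", ...) instead of being traced through, so its dynamic
    # conditional only ever executes eagerly at runtime.
    print(node.op, node.target)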