Unverified commit cbeda830 authored by anj-s, committed by GitHub


[Offload][feature] Add auto shard functionality to remove requirement of nn.Sequential models. (#695)

* auto wrap functionality

* lint and doc strings

* fix lint errors

* lint errors and version skips

* remove mypy checking and add conditional import

* another math.prod instance

* another import fix

* address comments

* lint errors

* address comments

* fix lint errors

* add placeholder nodes to tracker list
parent 7bdb9a7f
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List
import torch
import torch.fx
from torch.fx.node import Node


def _get_count(param_count: Dict, node_name: str) -> int:
    """Identify different mutations of a given node name."""
    # TODO(anj): This is not very stable since it is possible that the name
    # may not be in the same format. Is there another way to identify nodes
    # in a graph?
    if node_name in param_count:
        return param_count[node_name]
    elif node_name.split("_")[0] in param_count:
        return param_count[node_name.split("_")[0]]
    else:
        raise RuntimeError(f"Unable to find match between param {param_count} and node {node_name}")


def _create_shard_to_param_count(param_count: Dict, node_name_to_shard_id: Dict) -> Dict:
    """Utility to create a map from shard id to param count using existing state."""
    shard_to_param_count: Dict[int, int] = {}
    for node_name in node_name_to_shard_id.keys():
        try:
            count = _get_count(param_count, node_name)
        except RuntimeError:
            continue
        if node_name_to_shard_id[node_name] in shard_to_param_count:
            shard_to_param_count[node_name_to_shard_id[node_name]] += count
        else:
            shard_to_param_count[node_name_to_shard_id[node_name]] = count
    return shard_to_param_count


def _split_nodes(model: torch.nn.Module, shard_count: int = 3) -> Dict:
    """Utility used to trace a graph and identify shard cutpoints."""
    node_name_to_shard_id: Dict[str, int] = {}
    shard_id = 0
    nodes_so_far = []
    param_count: Dict[str, int] = {}
    shard_to_param_count = {}

    traced_graph_module = torch.fx.symbolic_trace(model)

    # Find the total number of params in the model and
    # the number of params per shard we are aiming for.
    for name, module in model.named_modules():
        if "." in name:
            continue
        param_count[name] = sum([x.numel() for x in module.parameters()])

    logging.info(f"Total number of params are {param_count['']}")
    per_shard_param = param_count[""] // shard_count
    logging.info(f"Per shard param count {per_shard_param}")

    for node in traced_graph_module.graph.nodes:
        if node.op == "placeholder":
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
        elif node.op in ["get_attr", "call_function", "call_method", "call_module"]:
            min_shard_id = shard_id
            min_node_name = ""
            # For each of the args of a given node, find the arg that is not the
            # last node we traversed. This is to help us find skip connections
            # across shards.
            for arg in node.args:
                # If the node has args that are inputs to the forward function, they
                # may not have explicit names.
                if not hasattr(arg, "name"):
                    continue
                if arg.name in node_name_to_shard_id and arg.name != nodes_so_far[-1]:
                    if node_name_to_shard_id[arg.name] < min_shard_id:
                        min_shard_id = node_name_to_shard_id[arg.name]
                        min_node_name = arg.name
            # If there is an input that is not from the previous shard,
            # we collapse all the shards in between into one shard
            # and update the param count per shard accordingly.
            if min_shard_id < shard_id:
                for node_name in reversed(nodes_so_far):
                    node_name_to_shard_id[node_name] = min_shard_id
                    if node_name == min_node_name:
                        break
                shard_id = min_shard_id
                # TODO(anj-s): Find a way to raise an error early if this can cause OOM errors.
                shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)

            # Update state that is tracking node -> shard id and shard id -> param count.
            node_name_to_shard_id[node.name] = shard_id
            nodes_so_far.append(node.name)
            # TODO(anj): This could just be an update; we don't need to recreate the map.
            shard_to_param_count = _create_shard_to_param_count(param_count, node_name_to_shard_id)
            # If we have gone over the number of params per shard that we want to
            # achieve, we should start a new shard.
            # The shard_id may not have been updated in the map if we are at a node
            # that does not have params.
            if shard_id in shard_to_param_count and shard_to_param_count[shard_id] > per_shard_param:
                shard_id += 1
        elif node.op == "output":
            break
    return node_name_to_shard_id


def shard_model(model: torch.nn.Module, shard_count: int = 3) -> List[torch.fx.GraphModule]:
    """Utility used to shard a model using torch.fx.

    This function traces the model twice in an attempt to identify the
    right cutpoints and then shard the model. In the first pass we calculate
    the number of parameters as we are tracing the graph and mark nodes at
    which we might want to create a new module. In the second pass we
    modify the graph by inserting placeholders and output nodes to essentially
    shard the graph.

    We don't support skip connections between shards. This means that all
    inputs and outputs are self contained within a given shard. A node from
    shard 1 cannot be an input to a node from shard 3. We expect all inputs
    to a given shard to come from the last node in the previous shard.
    This means that we may not be able to shard the model into the exact
    `shard_count` specified by the user.

    Args:
        model (nn.Module): Model to be sharded as specified by the device count.
        shard_count (int): Number of shards that we want to split the model into.
    """
    module_list: List[torch.fx.GraphModule] = []
    num_graphs = 0
    new_graph = torch.fx.Graph()  # type: ignore
    env: Dict[str, Node] = {}
    new_input_node = None

    # This is the first pass where we attempt to get a map of where
    # we need to insert placeholder and output nodes.
    node_name_to_shard_id = _split_nodes(model, shard_count=shard_count)

    traced_graph_module = torch.fx.symbolic_trace(model)

    # Dummy value which indicates that this is the first node.
    prev_shard_id = 1000
    prev_node = None

    for node in traced_graph_module.graph.nodes:
        # If the current node is in the next shard, we insert an output node.
        # A new graph is created and a placeholder is added for the next shard.
        if node.name in node_name_to_shard_id and prev_shard_id < node_name_to_shard_id[node.name]:
            assert prev_node, "prev_node cannot be None"

            with new_graph.inserting_after(prev_node):
                new_graph.output(env[prev_node.name])
            num_graphs += 1
            module_list.append(torch.fx.GraphModule(model, new_graph))
            new_graph = torch.fx.Graph()
            node_name = "placeholder" + str(num_graphs)
            pl_node = new_graph.create_node("placeholder", node_name)
            env[node_name] = pl_node
            new_input_node = pl_node

        if new_input_node is not None:
            # Account for a placeholder in the new graph.
            node.args = (new_input_node,)
            new_input_node = None

        if node.op in ["placeholder", "get_attr", "call_function", "call_method", "call_module"]:
            # Copy the nodes from the existing graph to the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
        elif node.op == "output":
            # If this is the last node, we should add an output
            # node and add the last graph to the list.
            assert prev_node, "prev_node cannot be None"

            with new_graph.inserting_after(prev_node):
                new_graph.output(env[prev_node.name])
            module_list.append(torch.fx.GraphModule(model, new_graph))
            break

        prev_node = new_node
        prev_shard_id = node_name_to_shard_id[node.name]

    return module_list
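
As the `shard_model` docstring above explains, the returned shards form a chain: each shard's output feeds the next shard's placeholder input. A minimal usage sketch follows (the toy module, tensor sizes, and `shard_count` are illustrative assumptions, not part of this commit; `torch.fx` requires torch >= 1.8):

# Illustrative sketch only: shard a small non-Sequential nn.Module and run the
# resulting torch.fx.GraphModule shards back to back, mirroring the loop in
# test_auto_shard.py below. The ToyModel class and tensor shapes are assumptions.
import torch
import torch.nn as nn

from fairscale.experimental.nn.auto_shard import shard_model


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 4)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


shards = shard_model(ToyModel(), shard_count=2)  # List[torch.fx.GraphModule]

# Each shard's output becomes the next shard's input.
out = torch.randn(8, 16)
for shard in shards:
    out = shard(out)
assert out.shape == torch.Size([8, 4])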
@@ -426,7 +426,7 @@ class OffloadModel(nn.Module):
     def __init__(
         self,
-        model: nn.Sequential,
+        model: Any,
         device: torch.device,
         offload_device: torch.device = torch.device("cpu"),
         num_slices: int = 3,
@@ -440,7 +440,7 @@ class OffloadModel(nn.Module):
         if not device:
             raise TypeError("`device` argument to `OffloadModel` cannot be None.")
-        if not isinstance(model, nn.Sequential):
+        if not (isinstance(model, nn.Sequential) or type(model) == list):
             raise TypeError("`model` argument to `OffloadModel` must be of type `nn.Sequential`.")
         if not torch.cuda.is_available():
@@ -448,20 +448,28 @@ class OffloadModel(nn.Module):
         self.device = device
         self.offload_device = offload_device

-        # Slice the model into roughly equivalent sequential shards.
-        splits = _split(model, num_slices)
-
         # List of model shards that will be placed on/off the device.
         self.model_slices: List[nn.Module] = []

-        for i, split in enumerate(splits):
-            # Add one model handling this slice
-            self.model_slices.append(
-                ModelShard(
-                    cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i,
-                )
-            )
+        # TODO(anj): Add an experimental flag for using this instead of modifying the
+        # arg type.
+        if type(model) == list:
+            # This is already sharded using the auto shard functionality.
+            for i, m in enumerate(model):
+                self.model_slices.append(
+                    ModelShard(cpu_model_shard=m, device=device, offload_device=offload_device, index=i,)
+                )
+        else:
+            # Slice the model into roughly equivalent sequential shards.
+            splits = _split(model, num_slices)
+            for i, split in enumerate(splits):
+                # Add one model handling this slice
+                self.model_slices.append(
+                    ModelShard(
+                        cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i,
+                    )
+                )

         # Expose a unified view of the slices
         self._model = torch.nn.Sequential(*self.model_slices)
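
With the `OffloadModel` change above, the list returned by `shard_model` can be passed directly instead of an `nn.Sequential`; each list element becomes its own `ModelShard`, and `num_slices` is not used on that path. A hedged sketch (assumes a CUDA device, torch >= 1.8, and a hypothetical toy model; keyword names follow the `__init__` shown in this hunk):

# Sketch only, under the assumptions stated above: feed auto-sharded modules
# straight into OffloadModel.
import torch
import torch.nn as nn

from fairscale.experimental.nn.auto_shard import shard_model
from fairscale.experimental.nn.offload import OffloadModel


class ToyModel(nn.Module):  # hypothetical example model, not from this commit
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


sharded = shard_model(ToyModel(), shard_count=3)  # List[torch.fx.GraphModule]
offload = OffloadModel(
    model=sharded,  # list branch added in this hunk; num_slices is ignored here
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
)
out = offload(torch.randn(8, 16, device="cuda"))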
@@ -55,6 +55,9 @@ warn_unused_ignores = true
 [mypy-fairscale.experimental.nn.distributed_pipeline.trace]
 ignore_errors = True

+[mypy-fairscale.experimental.nn.auto_shard]
+ignore_errors = True
+
 [mypy-benchmarks.*]
 ignore_errors = True
@@ -43,4 +43,5 @@ tests/experimental/nn/test_multiprocess_pipe.py
 tests/experimental/nn/test_sync_batchnorm.py
 tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
 tests/experimental/nn/test_offload.py
+tests/experimental/nn/test_auto_shard.py
 tests/experimental/optim/test_dynamic_loss_scaler.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Testing auto shard functionality of non-nn.Sequential models.
"""
import math
import pytest
import torch
import torch.nn
import torch.nn as nn
from fairscale.utils.testing import torch_version


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # TODO(anj): Fix the following error when using autoshard
        # Error: TypeError: slice indices must be integers or None or have an __index__ method
        # x = x + self.pe[:x.size(0), self.d_model]
        return self.dropout(x)


class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = torch.nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, *args):
        src = args[0]
        src_mask = args[1]
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

bptt = 35
ntokens = 28783 # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 1 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value


def test_single_run():
    if torch_version() < (1, 8, 0):
        pytest.skip("requires torch version >= 1.8.0")

    from fairscale.experimental.nn.auto_shard import shard_model

    model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout)
    sharded_model = shard_model(model)
    assert len(sharded_model) == 2, "Length of sharded model is incorrect."

    expected_param_nums = [5998600, 5785383]
    for i, model in enumerate(sharded_model):
        param_count = {}
        for name, module in model.named_modules():
            if "." in name:
                continue
            param_count[name] = sum([x.numel() for x in module.parameters()])
        assert expected_param_nums[i] == param_count[""]

    src_mask = torch.randn((35, 35), dtype=torch.float32)
    src = torch.randint(1, ntokens, (35, 20))
    input = [src, src_mask]

    for model in sharded_model:
        if type(input) == list:
            input = model(*input)
        else:
            input = model(input)

    assert input.size() == torch.Size([35, 20, 28783])
@@ -15,7 +15,10 @@ import pytest
 import torch

 from fairscale.experimental.nn.offload import OffloadModel
-from fairscale.utils.testing import skip_if_no_cuda
+from fairscale.utils.testing import skip_if_no_cuda, torch_version
+
+if torch_version() >= (1, 8, 0):
+    from fairscale.experimental.nn.auto_shard import shard_model


 def _init():
@@ -135,7 +138,11 @@ def _train_offload_model(
 @pytest.mark.parametrize("use_fp16", [True, False])
 @pytest.mark.parametrize("checkpoint_activation", [True, False])
 @pytest.mark.parametrize("num_microbatches", [1, 5])
-def test_correctness(use_fp16, checkpoint_activation, num_microbatches):
+@pytest.mark.parametrize("use_auto_shard", [True, False])
+def test_correctness(use_fp16, checkpoint_activation, num_microbatches, use_auto_shard):
+    if use_auto_shard and torch_version() < (1, 8, 0):
+        pytest.skip("auto_shard requires torch version >= 1.8.0")
+
     if (use_fp16 or checkpoint_activation) and not hasattr(torch.cuda.amp, "custom_fwd"):
         pytest.skip(f"AMP APIs are not supported in torch version {torch.__version__}")

@@ -144,9 +151,14 @@ def test_correctness(use_fp16, checkpoint_activation, num_microbatches):
     device, offload_device = _init()
     model = _get_model()

+    if use_auto_shard:
+        offload_model = shard_model(model)
+    else:
+        offload_model = model
+
     rmodel, ropt, rloss = _train_reg_model(model, device, offload_device)
     omodel, oopt, oloss = _train_offload_model(
-        model,
+        offload_model,
         device,
         offload_device,
         use_fp16=use_fp16,