Commit e532679c authored by oahzxl

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
from typing import Optional
import torch
from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import MemStats
from .data_parallel import ZeroDDP
class GeminiDDP(ZeroDDP):
def __init__(self,
module: torch.nn.Module,
device: torch.device,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: Optional[float] = None,
memstats: Optional[MemStats] = None) -> None:
"""
A torch.nn.Module wrapper using ZeRO-DP and Gemini.
ZeRO is for parallelism. Gemini is for memory management.
WARNING: The class will modify the module in-place!
Example:
The model is initialized under the context of ColoInitContext.
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): the model to be wrapped.
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
search_range_mb (int, optional): chunk size searching range in MegaBytes. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of the DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is fine. A default value of 1024 will be used.
min_chunk_size_mb (float, optional): the minimum chunk size in MegaBytes.
If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
"""
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
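# A minimal end-to-end training sketch for the wrapper above. It assumes the
# distributed environment was already set up (e.g. via colossalai.launch), that
# `model` was built under ColoInitContext, and that `dataloader` exists; both
# names are placeholders, not part of this file.
import torch

model = GeminiDDP(model, torch.cuda.current_device(), placement_policy="auto")
criterion = torch.nn.CrossEntropyLoss()
for x, labels in dataloader:
    logits = model(x.cuda())
    loss = criterion(logits, labels.cuda())
    model.backward(loss)    # ZeroDDP/GeminiDDP runs backward itself instead of loss.backward()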
from collections import OrderedDict
from copy import copy
from typing import Optional, Set
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device
def get_temp_total_chunk_on_cuda(chunk: Chunk):
if chunk.is_gathered:
return chunk.chunk_total
return chunk.cuda_global_chunk
if chunk.cuda_shard is not None:
shard_temp = chunk.cuda_shard
......@@ -18,3 +24,89 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
return total_temp
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
"""Get a dfs module list of the given module. Its order is same as the order of creations of modules.
"""
if memo is None:
memo = set()
if module not in memo:
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
yield m
memo.add(module)
yield prefix, module
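# A quick illustration of the traversal order on a toy model: children are
# yielded before their parent, matching module-creation order. The toy
# `nn.Sequential` here is only for demonstration.
import torch.nn as nn

_toy_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
print([name for name, _ in _get_dfs_module_list(_toy_net)])
# -> ['0', '1', '']: the two submodules first, then the root with an empty prefix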
def _get_shallow_copy_model(model: nn.Module):
"""Get a shallow copy of the given model. Each submodule is different from the original submodule.
But the new submodule and the old submodule share all attributes.
"""
old_to_new = dict()
for name, module in _get_dfs_module_list(model):
new_module = copy(module)
new_module._modules = OrderedDict()
for subname, submodule in module._modules.items():
if submodule is None:
continue
setattr(new_module, subname, old_to_new[submodule])
old_to_new[module] = new_module
return old_to_new[model]
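# A minimal demonstration of the shallow-copy semantics: module objects differ,
# but parameters (attributes) are shared. The toy model is only for illustration.
import torch.nn as nn

_orig = nn.Sequential(nn.Linear(2, 2))
_copy = _get_shallow_copy_model(_orig)
assert _copy is not _orig and _copy[0] is not _orig[0]    # new module objects
assert _copy[0].weight is _orig[0].weight                 # shared attributes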
def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given ZeroDDP module.
You should notice that the original ZeroDDP model is not modified.
Thus, you can use the original model in further training.
But you should not use the returned torch model to train, this can cause unexpected errors.
Args:
zero_ddp_model (ZeroDDP): a zero ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank 0 has the converted torch model
Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
from colossalai.nn.parallel import ZeroDDP
assert isinstance(zero_ddp_model, ZeroDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
colo_model = zero_ddp_model.module
torch_model = _get_shallow_copy_model(colo_model)
if not only_rank_0 or dist.get_rank() == 0:
# record the mapping relationship between colo parameters and torch parameters
colo_to_torch = dict()
for (name, colo_module), (_, torch_module) in \
zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
# clean the parameter list of the new torch module
torch_module._parameters = OrderedDict()
for suffix_param_name, param in colo_module.named_parameters(recurse=False):
# get the full name of the parameter
full_param_name = name + ('.' if name else '') + suffix_param_name
if full_param_name not in state_dict:
# this means the parameter is shared by multiple modules
# we should use colo_to_torch to get the torch parameter created before
assert param in colo_to_torch, f"cannot find parameter `{full_param_name}` in the GeminiDDP module"
torch_param = colo_to_torch[param]
else:
# we meet the parameter the first time, just use the state dict to get the data
state_param = state_dict[full_param_name]
torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
colo_to_torch[param] = torch_param
setattr(torch_module, suffix_param_name, torch_param)
dist.barrier()
return torch_model
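# An illustrative checkpoint-saving sketch. It assumes an initialized distributed
# environment; `zero_model` is a placeholder for a ZeroDDP/GeminiDDP instance.
import torch
import torch.distributed as dist

torch_model = get_static_torch_model(zero_model, device=torch.device("cpu"),
                                     dtype=torch.float32, only_rank_0=True)
if dist.get_rank() == 0:
    torch.save(torch_model.state_dict(), "checkpoint.pt")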
from .pipelinable import PipelinableContext, PipelinableModel
from .layer_sepc import LayerSpec
from .layer_spec import LayerSpec
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
\ No newline at end of file
from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal
__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
\ No newline at end of file
from .fx import get_topology as get_fx_topology
__all__ = ['get_fx_topology']
\ No newline at end of file
from torch.fx.graph_module import GraphModule
from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
import torch
def partition_name_to_id(partition_name, is_input=False, is_output=False):
if is_input:
partition_id = 0
elif is_output:
partition_id = 1
else:
prefix = 'submod_'
partition_id = int(partition_name.split(prefix)[-1]) + 2
return partition_id
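# Sanity checks of the id scheme above (illustrative): the input partition is 0,
# the output partition is 1, and 'submod_k' maps to k + 2.
assert partition_name_to_id('', is_input=True) == 0
assert partition_name_to_id('', is_output=True) == 1
assert partition_name_to_id('submod_0') == 2
assert partition_name_to_id('submod_3') == 5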
# There are two kinds of defs in an fx.graph:
# 1. non direct_use & non direct_def: the output is consumed by the next partition through a temporary mid value.
# e.g. submod1 = call_module(...)
#      temporary_val = submod1[0]
#      submod2 = call_module(temporary_val, ...)
# 2. direct_use & direct_def: the output is consumed by the next partition directly.
# e.g. submod1 = call_module(...)
#      submod2 = call_module(submod1, ...)
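# An illustrative sketch of case 1 above; `Inner`, `Outer` and `KeepInnerLeaf`
# are made-up names. Keeping `Inner` as a leaf module forces its tuple output
# to reach consumers through `getitem` nodes, i.e. the "temporary mid value"
# pattern that find_input_in_partition has to handle.
import torch
import torch.fx

class Inner(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        out = self.inner(x)
        return out[0] - out[1]    # indexing the tuple output creates getitem nodes

class KeepInnerLeaf(torch.fx.Tracer):
    def is_leaf_module(self, m, qualname):
        return isinstance(m, Inner) or super().is_leaf_module(m, qualname)

graph = KeepInnerLeaf().trace(Outer())
print([node.name for node in graph.nodes])
# e.g. ['x', 'inner', 'getitem', 'getitem_1', 'sub', 'output']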
def find_input_in_partition(node, partitions, input_partitions=None):
p_input_val = None
direct_def = not node.name.startswith('getitem')
# search in input
if direct_def and input_partitions is not None:
partition_id = partition_name_to_id('', is_input=True)
for i, input_node in enumerate(input_partitions):
if input_node == node:
p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
return p_input_val
# search submod in mid part
if direct_def:
for partition in partitions:
if partition == node:
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=0)
return p_input_val
# search temporary value in graph
else:
for partition in partitions:
for offset, mid_val in enumerate(partition.users):
if mid_val == node:
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
return p_input_val
return p_input_val
def find_output_in_partition(node, partitions, output_partitions=None):
p_output_val = PartitionOutputVal()
for user in node.users:
direct_use = not user.name.startswith('getitem')
# user is mid partition
for partition in partitions:
# direct call
if direct_use:
if user == partition:
partition_id = partition_name_to_id(partition.name)
for i, arg in enumerate(partition.args):
if arg == node:
p_output_val.add(partition_id=partition_id, offset=i)
break
# getitem call
else:
if user in partition.args:
partition_id = partition_name_to_id(partition.name)
for i, arg in enumerate(partition.args):
if arg == user:
p_output_val.add(partition_id=partition_id, offset=i)
break
# user is output
if output_partitions is not None:
output_node = output_partitions[0]
if user.op == output_node.op:
output_keys = {}
partition_id = partition_name_to_id('', is_output=True)
torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
for i, arg in enumerate(output_keys):
if arg == node:
p_output_val.add(partition_id=partition_id, offset=i)
break
return p_output_val
def get_topology(gm: GraphModule):
topo = Topo()
topo_output_partition = Partition()
input_partitions = []
partitions = []
output_partitions = []
for node in gm.graph.nodes:
if node.op == 'placeholder':
input_partitions.append(node)
elif node.name.startswith('submod_'):
partitions.append(node)
elif node.op == 'output':
output_partitions.append(node)
else:
continue
# set output for input_partition
topo_input_partition = Partition()
for partition in input_partitions:
cur_node = partition
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_input_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=0, partition=topo_input_partition)
topo.set_input_partition_id(partition_id=0)
for i, partition in enumerate(partitions):
topo_mid_partition = Partition()
# set input for submodule
for arg in partition.args:
cur_node = arg
p_input_val = find_input_in_partition(cur_node, partitions, input_partitions)
topo_mid_partition.add_input_val(p_input_val)
# set output for submodule
direct_use = True
for user in partition.users:
if user.name.startswith('getitem'):
direct_use = False
break
if direct_use:
cur_node = partition
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_mid_partition.add_output_val(p_output_val)
else:
for user in partition.users:
cur_node = user
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_mid_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=i+2, partition=topo_mid_partition)
# set input for output_partition
for partition in output_partitions:
topo_output_partition = Partition()
torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
find_input_in_partition(n, partitions, input_partitions)))
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition_id(partition_id=1)
return topo
\ No newline at end of file
from typing import Dict, List
from dataclasses import dataclass
# This file includes the data structures used by the Pipeline Middleware.
@dataclass
class ValPosition:
partition_id: int
offset: int
def __str__(self) -> str:
res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
return res
def __repr__(self) -> str:
return self.__str__()
class PartitionInputVal(object):
def __init__(self, partition_id, offset) -> None:
# records which partition_id and which offset every input comes from
val_pos = ValPosition(partition_id, offset)
self._from_partition_and_offset: ValPosition = val_pos
def get(self):
return self._from_partition_and_offset
def __str__(self) -> str:
res = ''
res += f'<-({self._from_partition_and_offset})'
return res
def __repr__(self) -> str:
return self.__str__()
class PartitionOutputVal(object):
def __init__(self) -> None:
# records which partition_id and which offset every output goes to
self._to_partition_and_offset: List[ValPosition] = []
def add(self, partition_id, offset):
val_pos = ValPosition(partition_id, offset)
self._to_partition_and_offset.append(val_pos)
def get(self):
return self._to_partition_and_offset
def __str__(self) -> str:
res = ''
res += '->('
for val_pos in self._to_partition_and_offset:
res += f'{val_pos},'
res += ')'
return res
def __repr__(self) -> str:
return self.__str__()
class Partition(object):
def __init__(self) -> None:
self._input_vals: List[PartitionInputVal] = []
self._output_vals: List[PartitionOutputVal] = []
def add_input_val(self, input_val: PartitionInputVal):
self._input_vals.append(input_val)
def add_output_val(self, output_val: PartitionOutputVal):
self._output_vals.append(output_val)
def get_input_vals(self):
return self._input_vals
def get_output_vals(self):
return self._output_vals
# get the output offsets sent to dst_partition_id
def get_output_offsets(self, dst_partition_id):
res = []
for offset, output_val in enumerate(self._output_vals):
outputs = output_val.get()
for val_pos in outputs:
if val_pos.partition_id == dst_partition_id:
res.append(offset)
return res
# get all input dst partition_ids
def get_input_partition_ids(self):
res = []
for input_val in self._input_vals:
val_pos = input_val.get()
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
# get all output dst partition_ids
def get_output_partition_ids(self):
res = []
for output_val in self._output_vals:
outputs = output_val.get()
for val_pos in outputs:
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
def __str__(self) -> str:
res = ''
res += f' input:\n'
res += f' length:{len(self._input_vals)}\n'
for i, input_val in enumerate(self._input_vals):
res += f' offset={i}:{input_val}\n'
res += f' output:\n'
res += f' length:{len(self._output_vals)}\n'
for i, output_val in enumerate(self._output_vals):
res += f' offset={i}:{output_val}\n'
return res
def __repr__(self) -> str:
return self.__str__()
# This class is a middleware between the partition splitter
# and the Pipeline Scheduler. It records the graph info about
# partition inputs/outputs and provides it to the scheduler.
# There are three kinds of partitions in the Pipeline Middleware design,
# which together represent the whole process of a model execution: input-fwd-output
# 1. input_partition: records the input of a model.
# 2. mid_partition: records the split forward executions of a model.
# 3. output_partition: records the output of a model.
# attributes:
# _partitions: include all partitions
# _input_partition_id: the key that represents the input_partition
# _output_partition_id: the key that represents the output_partition
class Topo(object):
def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
self._partitions: Dict[int, Partition] = {}
self._input_partition_id = input_partition_id
self._output_partition_id = output_partition_id
def set_input_partition_id(self, partition_id: int):
self._input_partition_id = partition_id
def set_output_partition_id(self, partition_id: int):
self._output_partition_id = partition_id
def get_input_partition_id(self):
return self._input_partition_id
def get_output_partition_id(self):
return self._output_partition_id
def set_partitions(self, partition_id: int, partition: Partition):
self._partitions[partition_id] = partition
def get_mid_partitions(self):
res = {} #{partition_id: Partition}
for partition_id, partition in self._partitions.items():
if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
continue
res[partition_id] = partition
return res
def get_mid_partition_ids(self):
return list(self.get_mid_partitions().keys())
def get_input_partition(self):
if self._input_partition_id is not None:
return self._partitions[self._input_partition_id]
return None
def get_output_partition(self):
if self._output_partition_id is not None:
return self._partitions[self._output_partition_id]
return None
def get_partition_by_id(self, partition_id):
return self._partitions[partition_id]
def __str__(self) -> str:
res = ''
if len(self._partitions) == 0:
return 'Empty Topo Graph.'
input_part = self.get_input_partition()
if input_part is not None:
res += '{\n'
res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
res += '}\n'
mid_parts = self.get_mid_partitions()
for i, (partition_id, part) in enumerate(mid_parts.items()):
res += '{\n'
res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
res += '}\n'
output_part = self.get_output_partition()
if output_part is not None:
res += '{\n'
res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
res += '}\n'
return res
def __repr__(self) -> str:
return self.__str__()
\ No newline at end of file
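# A hand-built illustration of the structures above: one input partition (id 0),
# one mid partition (id 2) and one output partition (id 1), mirroring how
# get_topology() wires a single-stage model.
topo = Topo()

input_part = Partition()
out_val = PartitionOutputVal()
out_val.add(partition_id=2, offset=0)    # the model input feeds arg 0 of the mid partition
input_part.add_output_val(out_val)
topo.set_partitions(partition_id=0, partition=input_part)
topo.set_input_partition_id(partition_id=0)

mid_part = Partition()
mid_part.add_input_val(PartitionInputVal(partition_id=0, offset=0))
mid_out = PartitionOutputVal()
mid_out.add(partition_id=1, offset=0)    # the mid partition's output feeds the model output
mid_part.add_output_val(mid_out)
topo.set_partitions(partition_id=2, partition=mid_part)

output_part = Partition()
output_part.add_input_val(PartitionInputVal(partition_id=2, offset=0))
topo.set_partitions(partition_id=1, partition=output_part)
topo.set_output_partition_id(partition_id=1)

assert topo.get_mid_partition_ids() == [2]
assert input_part.get_output_offsets(2) == [0]
print(topo)    # pretty-prints the three partitions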
......@@ -9,7 +9,7 @@ from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .layer_sepc import LayerSpec
from .layer_spec import LayerSpec
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
......
......@@ -8,18 +8,28 @@ from typing import Any, Callable, Dict, List, Tuple
import torch
import torch.distributed.rpc as rpc
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
split_batch, tensor_shape_list, type_detail)
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (
get_batch_lengths,
pyobj_map,
pytree_filter,
pytree_map,
split_batch,
tensor_shape_list,
type_detail,
)
class Phase(Enum):
FORWARD = 0
BACKWARD = 1
UPDATE = 2
INPUT = 3
class UniqueKey:
......@@ -134,6 +144,7 @@ class WorkerBase(ABC):
self.partition_args = partition_args
self.criterion = criterion
self.metric = metric
self.reset = False
# context to maintain loop
self._initialize_context_container()
......@@ -164,6 +175,7 @@ class WorkerBase(ABC):
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.label_lock = threading.Condition(threading.Lock())
self.reset_condition = threading.Condition(threading.Lock())
def _initialize_partition(self):
partition_fn = self.partition_fn
......@@ -173,6 +185,41 @@ class WorkerBase(ABC):
self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
self.partition_condition_lock.notify_all()
def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None):
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
output = output_work_item.output
if not ref_use and output_work_item.phase != Phase.INPUT:
self.output_list.pop(key)
if not ref_use and output_work_item.phase != Phase.INPUT:
output_work_item.refcount += 1
refcount = output_work_item.refcount
# lifecycle management for DAG scheduler
if output_work_item.phase == Phase.FORWARD:
lifecycle = len(self.get_consumer_stage_ids())
if self.is_model_output(): # an extra reference for scheduler collecting results
lifecycle += 1
elif output_work_item.phase == Phase.BACKWARD:
lifecycle = len(self.get_producer_stage_ids())
if self.is_model_input() and self._is_last_step(
output_work_item): # an extra reference for ensure_backward
lifecycle += 1
else:
lifecycle = 0
refcount = 0
with self.output_list_condition_lock:
if refcount < lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
if isinstance(output, Future):
output = output.wait()
return output
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has already synced global worker rrefs"
assert pp_rank_to_worker_rref is not None, "pp_rank_to_worker_rref must be a dict instead of None"
......@@ -182,23 +229,21 @@ class WorkerBase(ABC):
# construction of the partition is executed after the registration of pp_rank_to_worker_rref
self._initialize_partition()
def get_output_by_key(self, key: UniqueKey) -> Any:
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
output = output_work_item.output
if isinstance(output, Future):
output = output.wait()
output_work_item.refcount += 1
# all consumers have been satisfied, the work_item can be released
with self.output_list_condition_lock:
if output_work_item.refcount >= len(self.consumer_stage_ids):
self.output_list.pop(key)
# ref_use works for the lifecycle counter:
# if ref_use is True, the lifecycle counter won't be incremented.
# offsets supports getting partial output to reduce communication costs.
def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
output = self._get_output_all(key, ref_use, rank)
if offsets is None: # get all for non iterable output
return output
else: # get part for iterable output
output = [output[i] for i in offsets]
return output
def get_numels(self) -> int:
numel = sum(param.numel() for param in self.module_partition.parameters())
return numel
def get_parameters(self) -> List[torch.Tensor]:
return [p for p in self.module_partition.parameters()]
......@@ -215,8 +260,10 @@ class WorkerBase(ABC):
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
return self.module_partition.state_dict()
def _make_args_kwargs(self, microbatch):
def _make_args_kwargs(self, microbatch, merge=False):
if isinstance(microbatch, dict):
if merge:
return list(microbatch.values()), {}
return [], microbatch
elif isinstance(microbatch, torch.Tensor):
return [microbatch], {}
......@@ -228,16 +275,21 @@ class WorkerBase(ABC):
kwargs.update(arg)
else:
args.append(arg)
if merge:
arg_lst = args
for arg in kwargs.values():
arg_lst.append(arg)
return arg_lst, {}
return args, kwargs
else:
raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")
# just for first pp_rank
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
assert self.consumer_stage_ids is not None
key = UniqueKey(microbatch_id, Phase.FORWARD)
output = self._get_future_by_device()
if not self.use_middleware():
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
......@@ -246,6 +298,35 @@ class WorkerBase(ABC):
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
else:
# make args and kwargs
arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)
# the first stage assigns the correct inputs to the other stages
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
input_partition = topo.get_input_partition()
self_input_offsets = input_partition.get_output_offsets(self_partition_id)
recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
# set input for self rank
self_arg_lst = []
for off in self_input_offsets:
self_arg_lst.append(arg_lst[off])
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# put input tensor which other nodes need into output_list as Phase.INPUT
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
self.num_microbatches, forward_only)
with self.output_list_condition_lock:
self.output_list[recv_input_key] = work_item_remote
self.output_list_condition_lock.notify_all()
# just for last pp_rank
def set_labels(self, microbatch_id: int, microlabels: Any):
......@@ -268,79 +349,372 @@ class WorkerBase(ABC):
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
"""
You should call this function asynchronously
"""
assert self.producer_stage_ids is not None
producer_num = len(self.producer_stage_ids)
assert producer_num > 0, "only a stage that has producers can subscribe to producers"
stage_id = self.pp_rank
subscribe_forward_futures: List[Future] = [None] * producer_num
output = self._get_future_by_device()
if not self.use_middleware():
producer_num = len(self.producer_stage_ids)
subscribe_forward_futures: List[Future] = [None] * producer_num
for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
else:
producer_stage_ids = self.get_producer_stage_ids()
producer_num = len(producer_stage_ids)
if self.need_model_input():
producer_num += 1 # for input partition
subscribe_forward_futures: List[Future] = [None] * producer_num
# TODO(jiangziyue) get single value instead of the whole output
if self.need_model_input():
producer_stage_id = 0
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
offsets = self._get_input_offsets_by_index(target_index=0)
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
rank=self.pp_rank,
offsets=offsets)
for i in range(0, producer_num - 1):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
target_index = i + 1
offsets = self._get_input_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank)
else:
for i in range(producer_num):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
target_index = i
offsets = self._get_input_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank, offsets=offsets)
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
# add work_item to work_list
with self.work_list_condition_lock:
return work_item_from_producer
# TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
with self.work_list_condition_lock:
if key not in self.work_list:
# On the current PP middleware design for DAG, get_output_by_key used by _subscribe_producer
# can only be executed once for every producer-consumer stage pair, which is necessary
# to count the lifecycle of the work_item. So, keeping _subscribe_producer inside the same
# lock as the work_item queue operation guarantees the consistency of the lifecycle counter.
work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only)
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()
def subscribe_consumer(self, microbatch_id: int):
def _subscribe_consumer(self, microbatch_id: int):
"""
You should call this function asynchronously
"""
assert self.producer_stage_ids is not None
consumer_num = len(self.consumer_stage_ids)
assert consumer_num > 0, "only a stage that has consumers can subscribe to consumers"
stage_id = self.pp_rank
subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device()
if not self.use_middleware():
consumer_stage_ids = self.consumer_stage_ids
else:
consumer_stage_ids = self.get_consumer_stage_ids()
consumer_num = len(consumer_stage_ids)
subscribe_backward_futures: List[Future] = [None] * consumer_num
for i in range(consumer_num):
consumer_stage_id = self.consumer_stage_ids[i]
consumer_stage_id = consumer_stage_ids[i]
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
target_index = i
offsets = self._get_output_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_backward_futures[target_index] = []
else:
subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
consumer_output_key, rank=self.pp_rank, offsets=offsets)
# flatten args
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
microbatch_id, None, self.num_microbatches, False)
# add work_item to work_list
with self.work_list_condition_lock:
return work_item_from_consumer
def subscribe_consumer(self, microbatch_id: int):
key = UniqueKey(microbatch_id, Phase.BACKWARD)
assert key not in self.work_list
with self.work_list_condition_lock:
if key not in self.work_list:
# On the current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
# can only be executed once for every producer-consumer stage pair, which is necessary
# to count the lifecycle of the work_item. So, keeping subscribe_consumer inside the same
# lock as the work_item queue operation guarantees the consistency of the lifecycle counter.
work_item_from_consumer = self._subscribe_consumer(microbatch_id)
self.work_list[key] = work_item_from_consumer
self.work_list_condition_lock.notify_all()
def get_producer_stage_ids(self):
producer_stage_ids = []
rank = self.pp_rank
if not self.use_middleware():
prev_rank = rank - 1
if prev_rank >= 0:
producer_stage_ids.append(prev_rank)
else:
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
input_partition_ids = self_partition.get_input_partition_ids()
model_input_partition_id = topo.get_input_partition_id()
for partition_id in input_partition_ids:
# ignore the input partition in the current implementation;
# it will be handled specially.
if partition_id != model_input_partition_id:
producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
return producer_stage_ids
def get_consumer_stage_ids(self):
consumer_stage_ids = []
rank = self.pp_rank
if not self.use_middleware():
next_rank = rank + 1
if next_rank <= self.actual_stage_num - 1:
consumer_stage_ids.append(next_rank)
else:
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
output_partition_ids = self_partition.get_output_partition_ids()
model_output_partition_id = topo.get_output_partition_id()
for partition_id in output_partition_ids:
if model_output_partition_id != partition_id:
consumer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
return consumer_stage_ids
def _get_producer_consumer(self) -> None:
rank = self.pp_rank
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
# should be arranged in order, i.e. the order of the inputs of the current forward
self.producer_stage_ids = []
self.consumer_stage_ids = []
self.producer_stage_ids = self.get_producer_stage_ids()
self.consumer_stage_ids = self.get_consumer_stage_ids()
# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo):
partition_ids = topo.get_mid_partition_ids()
return partition_ids[pp_rank]
def partition_id_to_pp_rank(self, partition_id: int, topo: Topo):
partition_ids = topo.get_mid_partition_ids()
for i, id in enumerate(partition_ids):
if id == partition_id:
return i
def get_topo(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
if hasattr(self.module_partition, '_topo'):
return self.module_partition._topo
else:
return None
def use_middleware(self):
topo = self.get_topo()
return topo is not None
def _get_input_offsets_by_index(self, target_index):
res = []
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
model_input_partition_id = topo.get_input_partition_id()
input_vals = self_partition.get_input_vals()
producer_stage_ids = self.get_producer_stage_ids()
if self.need_model_input():
# 0 for data from input batch
# >= 1 for data from prev stages
base = 1
else:
# data from prev stages
base = 0
for val in input_vals:
val_pos = val.get()
src_partition_id = val_pos.partition_id
src_offset = val_pos.offset
src_index = base
src_partition = topo.get_partition_by_id(src_partition_id)
output_len = len(src_partition.get_output_vals())
# data from a non-input partition
if src_partition_id != model_input_partition_id:
src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
src_index = base
for i, stage_id in enumerate(producer_stage_ids):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
src_index = 0
# when output_len = 1, not iterable
if target_index == src_index:
if output_len == 1:
res = None # offset = None to get all outputs
return res
else:
res.append(src_offset)
return res
def _get_output_offsets_by_index(self, target_index):
res = []
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
output_vals = self_partition.get_output_vals()
consumer_stage_ids = self.get_consumer_stage_ids()
for val_list in output_vals:
# An output may be passed to many downstream stages.
target = None
for val_pos in val_list.get():
dst_partition_id = val_pos.partition_id
dst_offset = val_pos.offset
dst_partition = topo.get_partition_by_id(dst_partition_id)
input_len = len(dst_partition.get_input_vals())
dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
for i, stage_id in enumerate(consumer_stage_ids):
if stage_id == dst_stage_id:
dst_index = i
break
if target_index == dst_index:
if input_len == 1:
res = None # offset = None to get all outputs
return res
else:
res.append(dst_offset)
return res
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs_fwd(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
else: # get by offset
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
model_input_partition_id = topo.get_input_partition_id()
input_vals = self_partition.get_input_vals()
producer_stage_ids = self.get_producer_stage_ids()
if self.need_model_input():
# 0 for data from input batch
# >= 1 for data from prev stages
base = 1
else:
# data from prev stages
base = 0
for val in input_vals:
val_pos = val.get()
src_partition_id = val_pos.partition_id
src_offset = val_pos.offset
src_index = base
src_partition = topo.get_partition_by_id(src_partition_id)
output_len = len(src_partition.get_output_vals())
# data from a non-input partition
if src_partition_id != model_input_partition_id:
src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
src_index = base
for i, stage_id in enumerate(producer_stage_ids):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
src_index = 0
# when output_len = 1, not iterable
if output_len == 1:
target = args_or_kwargs[src_index]
else:
offsets = self._get_input_offsets_by_index(src_index)
real_offset = offsets.index(src_offset)
target = args_or_kwargs[src_index][real_offset]
flatten_args.append(target)
args_or_kwargs = flatten_args
return args_or_kwargs
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs_bwd(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
for i, arg in enumerate(args_or_kwargs):
args_or_kwargs[i] = arg.wait()
if args_or_kwargs is not None: # get by offset
flatten_args = []
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
output_vals = self_partition.get_output_vals()
consumer_stage_ids = self.get_consumer_stage_ids()
for val_list in output_vals:
# An output may be passed to many downstream stages.
target = None
for val_pos in val_list.get():
dst_partition_id = val_pos.partition_id
dst_offset = val_pos.offset
dst_partition = topo.get_partition_by_id(dst_partition_id)
input_len = len(dst_partition.get_input_vals())
dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
for i, stage_id in enumerate(consumer_stage_ids):
if stage_id == dst_stage_id:
dst_index = i
break
if input_len == 1:
part_grad = args_or_kwargs[dst_index]
else:
offsets = self._get_output_offsets_by_index(dst_index)
real_offsets = offsets.index(dst_offset)
part_grad = args_or_kwargs[dst_index][real_offsets]
if target is None:
target = part_grad
elif part_grad is not None:
target += part_grad
else:
continue
flatten_args.append(target)
args_or_kwargs = flatten_args
return args_or_kwargs
@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
......@@ -354,6 +728,23 @@ class WorkerBase(ABC):
def is_last_stage(self):
return self.pp_rank == self.actual_stage_num - 1
def need_model_input(self):
need_input = False
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition = topo.get_partition_by_id(self_partition_id)
partition_inputs = self_partition.get_input_partition_ids()
model_input_partition_id = topo.get_input_partition_id()
if model_input_partition_id in partition_inputs:
need_input = True
return not self.is_first_stage() and need_input
def is_model_output(self):
return self.is_last_stage()
def is_model_input(self):
return self.is_first_stage()
def _default_data_process_func(self, args_kwargs):
if self.is_first_stage():
args = args_kwargs[0]
......@@ -390,11 +781,16 @@ class WorkerBase(ABC):
# parse and integrate args and kwargs
if is_first_stage:
args = get_real_args_kwargs(args)
kwargs = get_real_args_kwargs(kwargs)
args = self._get_real_args_kwargs_fwd(args)
kwargs = self._get_real_args_kwargs_fwd(kwargs)
args_kwargs = (args, kwargs)
else:
args_kwargs = get_real_args_kwargs(args)
args_kwargs = self._get_real_args_kwargs_fwd(args)
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
process_types=torch.Tensor) # torch rpc doesn't support args or rets on GPU
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
process_types=torch.device) # change devices from last stage to current device
args, kwargs = data_process_func(args_kwargs)
......@@ -459,6 +855,9 @@ class WorkerBase(ABC):
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets on GPU
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatic
......@@ -486,21 +885,43 @@ class WorkerBase(ABC):
# overlap recompute and future.wait
if not is_last_stage:
grad_tensors = get_real_args_kwargs(args)
grad_tensors = self._get_real_args_kwargs_bwd(args)
else:
grad_tensors = None
# take tensor only (for only tensor can do backward)
stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
# TODO(jiangziyue) : All values which should do bp are torch.Tensor?
stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor)
grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor)
# output every input's grad to the producer, even if it has no grad (output None),
# to keep the offsets aligned with the topo's record.
if grad_tensors is not None:
filtered_outputs = []
filtered_grads = []
for i, grad in enumerate(grad_tensors):
stage_output = stage_outputs[i]
if stage_output.requires_grad and grad is not None:
filtered_outputs.append(stage_output)
filtered_grads.append(grad)
stage_outputs = filtered_outputs
grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
process_types=torch.Tensor) # torch rpc doesn't support args or rets on GPU
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
consume_result = []
if not is_first_stage:
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
# In the current design, the input must be flattened args.
for arg in stage_input_args:
if isinstance(arg, torch.Tensor):
consume_result.append(arg.grad)
else:
consume_result.append(None)
consume_result = pyobj_map(
consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets on GPU
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
......@@ -532,11 +953,11 @@ class WorkerBase(ABC):
def _hook_before_step(self):
pass
def _reset_context(self):
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self._initialize_outstanding_range()
# stall the main loop to wait for the next batch input
def _wait_for_reset(self):
with self.reset_condition:
self.reset_condition.wait_for(lambda: self.reset)
self.reset = False
# do the main loop to consume ready_list
def _work_loop(self):
......@@ -547,10 +968,10 @@ class WorkerBase(ABC):
# main loop
while True:
work_item_key = self._get_work_item_key()
# move current work item to output_list to activate subscribe in advance
with self.work_list_condition_lock:
work_item = self.work_list.pop(work_item_key)
self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
work_item = self.work_list[work_item_key]
with self.output_list_condition_lock:
# assert work_item_key not in self.output_list
......@@ -559,27 +980,37 @@ class WorkerBase(ABC):
consume_result = self._consume_work_item_by_phase(work_item)
with self.work_list_condition_lock:
self.work_list.pop(work_item_key)
work_item.output.set_result(consume_result)
# if is last step in one batch reset context and do step
if self._is_last_step(work_item):
self._hook_before_step()
if hasattr(self, 'optimizer') and not work_item.forward_only:
self.step()
self._reset_context()
self._wait_for_reset()
# reset context and resume loop
def reset_context(self):
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self._initialize_outstanding_range()
with self.work_list_condition_lock:
self.work_list.clear()
with self.output_list_condition_lock:
self.output_list.clear()
with self.reset_condition:
self.reset = True
self.reset_condition.notify_all()
def initialize_optimizer(self, optimizer_class: type, **kwargs):
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
self.step_lock = threading.Lock()
self.step_lock.acquire()
def wait_for_step(self):
self.step_lock.acquire()
def step(self):
self._hook_before_step()
self.optimizer.step()
self.optimizer.zero_grad()
self.step_lock.release()
class PipelineEngineBase(ABC, nn.Module):
......@@ -611,8 +1042,6 @@ class PipelineEngineBase(ABC, nn.Module):
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
self.step_futs: List[Future] = []
self._check_argument()
self._create_pp_rank_to_rpc_worker_id()
self._create_pp_rank_to_module_partition_id()
......@@ -692,6 +1121,15 @@ class PipelineEngineBase(ABC, nn.Module):
for fut in sync_futs:
fut.wait()
def remote_numels(self) -> Dict[int, int]:
numels = {}
actual_stage_num = self._get_actual_stage_num()
for stage_id in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[stage_id]
numel = worker_rref.rpc_sync().get_numels()
numels[stage_id] = numel
return numels
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
parameters = {}
actual_stage_num = self._get_actual_stage_num()
......@@ -728,9 +1166,14 @@ class PipelineEngineBase(ABC, nn.Module):
ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
else:
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
futs = []
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_sync().get_output_by_key(key)
fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[])
futs.append(fut)
for fut in futs:
fut.wait()
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
num_microbatches = self.num_microbatches
......@@ -748,6 +1191,7 @@ class PipelineEngineBase(ABC, nn.Module):
# TODO : add relationship between output_pp_ranks and parts of microlabels
worker_rref.remote().set_labels(microbatch_id, microlabels)
# TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
key = UniqueKey(microbatch_id, Phase.FORWARD)
for pp_rank in output_pp_ranks:
......@@ -756,10 +1200,16 @@ class PipelineEngineBase(ABC, nn.Module):
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
if not forward_only:
backward_result = []
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
worker_rref.rpc_sync().get_output_by_key(key)
fut = worker_rref.rpc_async().get_output_by_key(
key, offsets=[]) # only ensures the result exists; no need for the real data.
backward_result.append(fut)
for fut in backward_result:
fut.wait()
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
forward_result = []
......@@ -776,6 +1226,17 @@ class PipelineEngineBase(ABC, nn.Module):
return forward_result
def _reset_worker(self):
actual_stage_num = self._get_actual_stage_num()
reset_futs: List[Future] = []
for pp_rank in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
fut = worker_rref.rpc_async().reset_context()
reset_futs.append(fut)
for fut in reset_futs:
fut.wait()
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
batch_lengths = get_batch_lengths(batch)
batch_length = batch_lengths[0]
......@@ -800,7 +1261,7 @@ class PipelineEngineBase(ABC, nn.Module):
for microbatch_id in range(num_microbatches):
# control data input speed
# to prevent exceed of wait limitations
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
# self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
batch_start = microbatch_size * microbatch_id
batch_end = min(batch_start + microbatch_size, batch_length)
......@@ -824,11 +1285,9 @@ class PipelineEngineBase(ABC, nn.Module):
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
if not forward_only and hasattr(self, 'optimizer_class'):
# wait for all step
for pp_rank in self.pp_rank_to_worker_rref:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_sync().wait_for_step()
self.step()
self._reset_worker() # reset worker attributes for next batch
return forward_result
def initialize_optimizer(self, optimizer_class: type, **kwargs):
......@@ -839,10 +1298,11 @@ class PipelineEngineBase(ABC, nn.Module):
def step(self):
actual_stage_num = self._get_actual_stage_num()
step_futs: List[Future] = []
for pp_rank in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
fut = worker_rref.rpc_async().step()
self.step_futs.append(fut)
step_futs.append(fut)
for fut in self.step_futs:
for fut in step_futs:
fut.wait()
......@@ -3,11 +3,12 @@ from typing import Callable, Dict, List
import torch
import torch.distributed as dist
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
# Implementations of different pipeline schedules.
# <strategy>Worker defines the worker for each stage.
# <strategy>PipelineEngine is the class to use.
......@@ -86,12 +87,9 @@ class OneFOneBWorker(WorkerBase):
outstanding_min = actual_stage_num - pp_rank - 1
outstanding_max = actual_stage_num - pp_rank
self.outstanding_range = (outstanding_min, outstanding_max)
elif target_key.microbatch_id == num_microbatches - 1:
if target_key.microbatch_id == num_microbatches - 1:
self.outstanding_range = (0, 0)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key
......
......@@ -6,11 +6,25 @@ from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.futures import Future
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
if isinstance(obj, process_types):
return fn(obj)
elif type(obj) is dict:
return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
elif type(obj) is tuple:
return tuple(pyobj_map(o, fn, process_types) for o in obj)
elif type(obj) is list:
return list(pyobj_map(o, fn, process_types) for o in obj)
else:
return obj
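# Illustrative behavior check for pyobj_map: fn is applied to every object of the
# given process_types, recursing through dicts, tuples and lists.
_nested = {'a': [1, 2.5], 'b': (3, 'x')}
assert pyobj_map(_nested, fn=lambda v: v * 10, process_types=(int, float)) == \
    {'a': [10, 25.0], 'b': (30, 'x')}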
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
"""process object recursively, like pytree
......@@ -137,5 +151,5 @@ def parse_args():
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--master_addr', type=str, default='localhost')
parser.add_argument('--master_port', type=str, default='29020')
parser.add_argument('--num_worker_threads', type=str, default=128)
parser.add_argument('--num_worker_threads', type=int, default=128)
return parser.parse_args()
......@@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.layer.utils import CheckpointModule
from typing import List
from collections import OrderedDict
def _binary_partition(weights: List, start: int, end: int):
"""Returns the binary partition position of `weights`, given the start
......@@ -159,8 +160,10 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
kwargs_offset = 0
elif isinstance(input_tensor, torch.Tensor):
kwargs_offset = 1
else:
assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
elif isinstance(input_tensor, (tuple, OrderedDict)):
#assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
# Huggingface uses its own OrderedDict-based structures as the output
# between layers, so we have to disable this check.
kwargs_offset = len(input_tensor)
args_name_list = list(sig.parameters.keys())
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
......
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .distspec import ShardSpec
from .distspec import ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from . import distspec
from .colo_parameter import ColoParameter
from .utils import convert_parameter, named_params_with_colotensor
from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, ParamOpHookManager
from .colo_tensor import ColoTensor
from .comm_spec import CollectiveCommPattern, CommSpec
from . import distspec
from .compute_spec import ComputePattern, ComputeSpec
from .dist_spec_mgr import DistSpecManager
from .distspec import ReplicaSpec, ShardSpec
from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
'merge_same_dim_mesh_list'
]
import torch
from typing import Optional
import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec
def filter_colo_parameters(*args, **kwargs):
param_list = []
def get_colo_parameters(element) -> None:
if isinstance(element, list) or isinstance(element, tuple):
for e in element:
get_colo_parameters(e)
elif isinstance(element, dict):
raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.")
elif isinstance(element, ColoParameter):
param_list.append(element)
return
for a in args:
get_colo_parameters(a)
for v in kwargs.values():
get_colo_parameters(v)
def filter_args(func, *args):
return [arg for arg in args if func(arg)]
return param_list
def replace_args(args, kwargs, new_args):
......@@ -58,18 +75,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
    if ColoParamOpHookManager.has_hook():
        if not func.__name__.startswith('__'):
            if kwargs is None:
                kwargs = {}
            params = filter_colo_parameters(*args, **kwargs)
            if len(params) > 0:
                with torch._C.DisableTorchFunction():
                    new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                args, kwargs = replace_args(args, kwargs, new_args)
                ret = super().__torch_function__(func, types, args, kwargs)
                with torch._C.DisableTorchFunction():
                    ret = ColoParamOpHookManager.post_op(params, ret)
                return ret
    return super().__torch_function__(func, types, args, kwargs)
......
from copy import copy
from functools import lru_cache
from typing import Callable, Optional, Set

import torch

from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from colossalai.tensor.tensor_spec import ColoTensorSpec

from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS
@lru_cache(None)
......@@ -55,7 +57,7 @@ class ColoTensor(torch.Tensor):
The Colotensor can be initialized with a PyTorch tensor in the following ways.
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
>>> # The tensor passed in is a tensor after sharding but not a global tensor.
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
......@@ -67,6 +69,8 @@ class ColoTensor(torch.Tensor):
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
"""
torch_major = int(torch.__version__.split('.')[0])
torch_minor = int(torch.__version__.split('.')[1])
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""
......@@ -100,7 +104,6 @@ class ColoTensor(torch.Tensor):
self.process_group = spec.pg
self._type = TensorType.NONMODEL
self._graph_node = None
def has_compute_spec(self) -> bool:
return self.compute_spec is not None
......@@ -114,7 +117,7 @@ class ColoTensor(torch.Tensor):
def set_process_group(self, pg: ProcessGroup):
"""set_process_group
change the pg of the ColoTensor. Note that the valid use cases are limited.
It only works when the target pg is a DP or TP process group and the current dist spec of the tensor is Replica.
Args:
pg (ProcessGroup): target pg
......@@ -124,10 +127,10 @@ class ColoTensor(torch.Tensor):
# if the new pg is the same as the old pg, just returns
if self.process_group == pg:
return
assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
    "Can not set_process_group on a ColoTensor whose tp world size and dp world size are both greater than 1"
assert self.dist_spec.placement.value == 'r', \
    "Can not set_process_group on a ColoTensor whose dist spec is not Replica"
self.process_group = pg
......@@ -166,6 +169,16 @@ class ColoTensor(torch.Tensor):
if func in _COLOSSAL_OPS:
func = _COLOSSAL_OPS[func]
if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
# in order to trigger the pre-op hook in the forward of a checkpoint module,
# we have to capture the `backward` function
# and make sure that it is not called inside the `torch._C.DisableTorchFunction()` context
if func is torch.Tensor.backward:
    assert len(args) == 1    # backward only takes one tensor parameter here
    backward_tensor = torch.Tensor(args[0])
    tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
    return backward_tensor.backward(**tensor_kwargs)
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in _get_my_nowrap_functions():
......
......@@ -23,9 +23,9 @@ def _all_gather(tensor, comm_spec):
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
]
# without the following contiguous operation, the all-gather may produce unexpected results
tensor = tensor.contiguous()
dist.all_gather(tensor_list, tensor, group=process_group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output
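# Why the .contiguous() above matters, shown with a minimal local sketch:
# collectives such as all_gather expect a contiguous buffer, but views like
# transposes are non-contiguous and may be gathered incorrectly without a copy.
import torch

t = torch.randn(2, 3).t()    # a transposed view: shape (3, 2), non-contiguous
assert not t.is_contiguous()
t = t.contiguous()           # materialize a contiguous copy before the collective
assert t.is_contiguous()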
......@@ -37,11 +37,10 @@ def _split(tensor, comm_spec):
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, _ in process_groups_list:
if dist.get_rank() in rank_list:
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
output = torch.narrow(tensor, dim, start, length).contiguous()
return output
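# A small worked example (local, no process group needed) of the shard
# arithmetic used in _split: rank r of k ranks takes the slice
# [r * (n // k), (r + 1) * (n // k)) along the shard dim.
import torch

tensor = torch.arange(8).reshape(4, 2)    # rows [[0,1],[2,3],[4,5],[6,7]]
rank_list, my_rank = [0, 1], 1            # two ranks, we are rank 1
length = tensor.shape[0] // len(rank_list)
start = length * rank_list.index(my_rank)
shard = torch.narrow(tensor, 0, start, length).contiguous()
assert shard.tolist() == [[4, 5], [6, 7]]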
......@@ -69,17 +68,145 @@ def _all_to_all(tensor, comm_spec):
return output
def _all_reduce(tensor, comm_spec, async_op=False):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
if not tensor.is_contiguous():
    tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
return tensor
def _mix_gather(tensor, comm_spec):
'''
Implement mix gather operation on device mesh based on information provided by comm_spec.
Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is
different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather
only does all-gather in one dimension.
Assume the indices of the f and b target pairs are 'f' and 'b':
ShardingSpec => gather_dim, logical_process_axes
S0S1 => [b, f], (1, 0)
S1S0 => [b, f], (0, 1)
S01R => [f], (1, 1)
RS01 => [b], (1, 1)
Example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
S0S1:
leading_group_dim = 1
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
tensor_list = [(0,0),(0,1),(0,2),(0,3),(1,0),(1,1),(1,2),(1,3)] # [(slice_id_f, slice_id_b),...]
mesh_shape = (2,4)
cat_slice = [4,2]
tmp_tensor_list = [(...,shape[f],shape[b]*4,...),(...,shape[f],shape[b]*4,...)]
tmp_tensor_list[0] = torch.cat(((0,0),(0,1),(0,2),(0,3)), dim=b)
tmp_tensor_list[1] = torch.cat(((1,0),(1,1),(1,2),(1,3)), dim=b)
output = torch.cat((tmp_tensor_list[0],tmp_tensor_list[1]), dim=a)
S1S0:
leading_group_dim = 0
process_group = "[0, 4, 1, 5, 2, 6, 3, 7]"
tensor_list = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)]
mesh_shape = (2,4)
cat_slice = [2,4]
tmp_tensor_list = [(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...)]
tmp_tensor_list[0] = torch.cat(((0,0),(0,1)), dim=b)
tmp_tensor_list[1] = torch.cat(((1,0),(1,1)), dim=b)
tmp_tensor_list[2] = torch.cat(((2,0),(2,1)), dim=b)
tmp_tensor_list[3] = torch.cat(((3,0),(3,1)), dim=b)
S10R:
leading_group_dim = 0
process_group = "[0, 4, 1, 5, 2, 6, 3, 7]"
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
S01R:
leading_group_dim = 1
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
'''
total_slices = comm_spec.device_mesh.mesh_shape[0]
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
leading_group_dim = comm_spec.logical_process_axes[0]
assert len(comm_spec.device_mesh.process_groups_dict) == 1
_, process_group = comm_spec.device_mesh.process_groups_dict[0][0]
process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim]
# Global all_gather
dist.all_gather(tensor_list, tensor, group=process_group)
# reorder the gathered slices into leading-axis order; a more elegant method is still being explored
tensor_list_sorted = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)
]
for i in range(total_slices):
tensor_list_sorted[i] = tensor_list[process_number_list[i]]
tensor_list = tensor_list_sorted
if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
else:
mesh_shape = comm_spec.device_meshes.mesh_shape
cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
tmp_tensor_shape = list(tensor.shape)
tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
tmp_tensor_shape = torch.Size(tmp_tensor_shape)
tmp_tensor_list = [
torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1])
]
for i in range(cat_slice[1]):
tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]),
comm_spec.gather_dim[0]).contiguous()
output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous()
return output
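# A self-contained sketch (assumed semantics) of the slice reordering above:
# tensors gathered in global rank order are re-sorted so that the leading
# logical axis varies fastest; `process_number_list` here is a hypothetical
# stand-in for device_meshes.process_number_dict[leading_group_dim].
import numpy as np

def process_number_list(mesh_shape, leading_axis):
    mesh = np.arange(int(np.prod(mesh_shape))).reshape(mesh_shape)
    if leading_axis == 0:
        return mesh.transpose().flatten().tolist()    # axis 0 varies fastest
    return mesh.flatten().tolist()

# mesh (2, 4) = [[0, 1, 2, 3], [4, 5, 6, 7]], matching the docstring example
assert process_number_list((2, 4), 0) == [0, 4, 1, 5, 2, 6, 3, 7]
assert process_number_list((2, 4), 1) == [0, 1, 2, 3, 4, 5, 6, 7]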
def _mix_split(tensor, comm_spec):
'''
Implement mix split operation. Mix split is only called in the backward of mix gather (ctx keeps the comm_spec consistent between forward and backward).
Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split
because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension.
Assume index of f and b target pairs are 'f' and 'b'
S0S1 => [b, f], (1, 0)
S1S0 => [b, f], (0, 1)
S01R => [f], (0, 0)
RS01 => [b], (0, 0)
Example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
mesh_shape = comm_spec.device_meshes.mesh_shape
dim = comm_spec.gather_dim
total_slices = comm_spec.device_mesh.mesh_shape[0]
# Get global rank
rank = dist.get_rank()
leading_group_dim = comm_spec.logical_process_axes[0]
process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim]
rank = process_number_list.index(rank)
if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
length = tensor.shape[dim[0]] // total_slices
start = length * rank
output = torch.narrow(tensor, dim[0], start, length).contiguous()
else:
tensor_shape = [tensor.shape[dim[0]], tensor.shape[dim[1]]]
rank_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]]
start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]]
tmp_output = torch.narrow(tensor, dim[0], start[0], length[0]).contiguous()
output = torch.narrow(tmp_output, dim[1], start[1], length[1]).contiguous()
return output
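# A local sketch of the two-dimensional shard arithmetic in _mix_split:
# rank r in a mesh split [2, 4] over dims (0, 1) takes block
# (r % 2, r // 2) of the tensor.
import torch

tensor = torch.arange(32.).reshape(4, 8)
rank, rank_slice, dims = 5, [2, 4], (0, 1)
length = [tensor.shape[dims[0]] // rank_slice[0], tensor.shape[dims[1]] // rank_slice[1]]
start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]]
block = torch.narrow(tensor, dims[0], start[0], length[0]).contiguous()
block = torch.narrow(block, dims[1], start[1], length[1]).contiguous()
assert block.shape == (2, 2)    # rows 2-3, columns 4-5 of the 4x8 tensor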
class _ReduceGrad(torch.autograd.Function):
"""
A customized communication operation whose forward is an identity operation,
......@@ -205,6 +332,22 @@ class _AllToAll(torch.autograd.Function):
return _all_to_all(grad_outputs, ctx.comm_spec), None
class _MixGatherForwardMixSplitBackward(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return _mix_gather(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return _mix_gather(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return _mix_split(grad_output, ctx.comm_spec), None
def reduce_grad(input_, comm_spec):
return _ReduceGrad.apply(input_, comm_spec)
......@@ -225,12 +368,17 @@ def all_to_all(input_, comm_spec):
return _AllToAll.apply(input_, comm_spec)
def mixgather_forward_split_backward(input_, comm_spec):
return _MixGatherForwardMixSplitBackward.apply(input_, comm_spec)
class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
class CommSpec:
......@@ -256,7 +404,8 @@ class CommSpec:
gather_dim=None,
shard_dim=None,
logical_process_axis=None,
forward_only=False,
mix_gather=False):
self.comm_pattern = comm_pattern
self.sharding_spec = sharding_spec
self.gather_dim = gather_dim
......@@ -264,8 +413,14 @@ class CommSpec:
self.logical_process_axis = logical_process_axis
self.forward_only = forward_only
if isinstance(self.logical_process_axis, list):
if not mix_gather:
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.logical_process_axis = 0
else:
self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
# create a new member `logical_process_axes` to distinguish it from the flattened `logical_process_axis`
self.logical_process_axes = logical_process_axis
else:
self.device_mesh = self.sharding_spec.device_mesh
......@@ -290,6 +445,10 @@ class CommSpec:
elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_asex:{self.logical_process_axes})")
return ''.join(res_list)
......@@ -325,6 +484,11 @@ class CommSpec:
forward_communication_cost = 10
backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
# no need for axis because all devices are used in mix_gather
forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
backward_communication_cost = 10
if self.forward_only:
cost_dict["forward"] = forward_communication_cost
cost_dict["backward"] = 0
......@@ -357,4 +521,5 @@ pattern_to_func_dict = {
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: mixgather_forward_split_backward,
}
from contextlib import contextmanager

import torch
import torch.distributed as dist
# from colossalai.nn.layer.utils import divide
from numpy import prod
from packaging import version

from colossalai.logging import get_dist_logger
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.process_group import ProcessGroup

# TODO(jiaruifang) circular import, move the divide to colossalai.commons.
......
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Tuple

import torch

from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.tensor_spec import ColoTensorSpec
class ColoParamOpHook(ABC):
    """
    Hook which is triggered by each operation when operands contain ColoParameter.
    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
    ``post_forward``, ``pre_backward`` and ``post_backward``.
    These four methods take a list of ColoParameter as input arguments.
    """
@abstractmethod
......@@ -30,68 +33,79 @@ class ParamOpHook(ABC):
pass
class ColoParamOpHookManager:
    """
    Manage your param op hooks. It only has static methods.
    The only static method you should call is ``use_hooks(*hooks)``.
    """
    hooks: Tuple[ColoParamOpHook, ...] = tuple()
@staticmethod
@contextmanager
def use_hooks(*hooks: ColoParamOpHook):
    """Change the param op hooks you use. Nested calling is allowed.
    Example:
        >>> with ColoParamOpHookManager.use_hooks(*hooks):
        >>>     do_something()
        >>>     with ColoParamOpHookManager.use_hooks():
        >>>         # clear hooks
        >>>         do_something()
    """
    try:
        old_param_op_hooks = ColoParamOpHookManager.hooks
        ColoParamOpHookManager.hooks = hooks
        yield
    finally:
        ColoParamOpHookManager.hooks = old_param_op_hooks
@staticmethod
def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
    for hook in ColoParamOpHookManager.hooks:
        hook.pre_forward(params)

@staticmethod
def _trigger_post_forward(params: List[torch.Tensor]) -> None:
    for hook in ColoParamOpHookManager.hooks:
        hook.post_forward(params)

@staticmethod
def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
    for hook in ColoParamOpHookManager.hooks:
        hook.pre_backward(params)

@staticmethod
def _trigger_post_backward(params: List[torch.Tensor]) -> None:
    for hook in ColoParamOpHookManager.hooks:
        hook.post_backward(params)
@staticmethod
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
    ColoParamOpHookManager._trigger_pre_forward(params)
    grad_args, rear_args = _get_grad_args(*args)
    colo_info = _get_colo_tensors_info(*grad_args)
    rets = PreFwdPostBwd.apply(params, *grad_args)
    update_args = _update_colo_tensors(colo_info, *rets)
    if rear_args is None:
        return update_args
    else:
        arg_zero = (tuple(update_args),)
        return arg_zero + rear_args
@staticmethod
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
    ColoParamOpHookManager._trigger_post_forward(params)
    colo_info = _get_colo_tensors_info(arg)
    ret = PostFwdPreBwd.apply(params, arg)
    res = _update_colo_tensors(colo_info, ret)
    if len(res) == 1:
        return res[0]
    else:
        return res
@staticmethod
def has_hook() -> bool:
    return len(ColoParamOpHookManager.hooks) > 0
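# An illustrative hook (hypothetical, not part of the library) showing the
# four callbacks a ColoParamOpHook subclass must implement: it just counts
# how many parameter-touching ops run in forward and backward.
class CountingHook(ColoParamOpHook):

    def __init__(self):
        super().__init__()
        self.fwd_ops = 0
        self.bwd_ops = 0

    def pre_forward(self, params) -> None:
        self.fwd_ops += 1

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        self.bwd_ops += 1

    def post_backward(self, params) -> None:
        pass

# usage sketch:
#   hook = CountingHook()
#   with ColoParamOpHookManager.use_hooks(hook):
#       loss = model(x).sum()
#       loss.backward()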
class PreFwdPostBwd(torch.autograd.Function):
......@@ -99,11 +113,11 @@ class PreFwdPostBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, params, *args):
    ctx.params = params
    return args

@staticmethod
def backward(ctx, *grads):
    ColoParamOpHookManager._trigger_post_backward(ctx.params)
    return (None,) + grads
......@@ -116,14 +130,51 @@ class PostFwdPreBwd(torch.autograd.Function):
@staticmethod
def backward(ctx, *grads):
    ColoParamOpHookManager._trigger_pre_backward(ctx.params)
    return (None,) + grads
def _is_grad_tensor(obj) -> bool:
if torch.is_tensor(obj):
if obj.grad_fn is not None or obj.requires_grad:
return True
return False
def _has_grad_tensor(obj) -> bool:
if isinstance(obj, tuple) or isinstance(obj, list):
for x in obj:
if _has_grad_tensor(x):
return True
return False
elif isinstance(obj, dict):
for x in obj.values():
if _has_grad_tensor(x):
return True
return False
else:
return _is_grad_tensor(obj)
def _get_grad_args(*args):
    # if there are no grad tensors, do nothing
    if not _has_grad_tensor(args):
        return args, None
    # return the args unchanged if some top-level arg is already a grad tensor
    for obj in args:
        if _is_grad_tensor(obj):
            return args, None
    # otherwise, the first argument should be a tuple of grad tensors;
    # if it contains no grad tensor, the backward of PreFwdPostBwd can't be triggered
    arg_zero = args[0]
    if not isinstance(arg_zero, tuple):
        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
    check_grad_flag = False
    for obj in arg_zero:
        check_grad_flag |= _is_grad_tensor(obj)
    if not check_grad_flag:
        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
    return arg_zero, args[1:]
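# Illustration of _get_grad_args' two pass-through cases (runs locally, no
# distributed setup): a top-level grad tensor keeps the args intact, while a
# leading tuple of grad tensors is split from the trailing args.
import torch

x = torch.randn(2, requires_grad=True)
assert _get_grad_args(x, 1) == ((x, 1), None)               # direct grad tensor
assert _get_grad_args((x,), 'extra') == ((x,), ('extra',))  # tuple-wrapped grads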
def _get_colo_tensors_info(*args) -> list:
......
import math
import operator
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from functools import reduce
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
from .comm_spec import *
......@@ -28,6 +25,15 @@ class ShapeConsistencyOptions:
pass
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor:
shape_consistency_manager = ShapeConsistencyManager()
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
with torch.no_grad():
global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
global_sharding_spec)
return global_tensor
def set_shape_consistency_options(options: ShapeConsistencyOptions):
"""
Configure the shape consistency manager via function call.
......@@ -63,7 +69,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
assert isinstance(value, bool)
self._forward_only = value
def get_all_all_gather_spec(self, source_spec: ShardingSpec,
                            orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with single all-gather operation, and
accumulate the communication cost on the original cost, which will finally be used by the auto-sharding solver.
......@@ -71,7 +78,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
Argument:
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
orig_cost_dict(Dict[str, float]): the original communication cost before this operation.
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
......@@ -83,7 +90,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
......@@ -134,7 +141,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
pass
return valid_spec_dict
def get_all_all_to_all_spec(self, source_spec: ShardingSpec,
                            orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with single all-to-all operation, and
accumulate the communication cost on the original cost, which will finally be used by the auto-sharding solver.
......@@ -142,7 +150,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
Argument:
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
orig_cost_dict(Dict[str, float]): the original communication cost before this operation.
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
......@@ -154,7 +162,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
......@@ -241,7 +249,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
return valid_spec_dict
def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
'''
Get all valid sharding specs from source_spec with single shard operation, and
accumulate the communication cost on the original cost, which will finally be used by the auto-sharding solver.
......@@ -261,7 +269,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
......@@ -322,7 +330,60 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
pass
return valid_spec_dict
def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
                            orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
S0S1 -> RR
S1S0 -> RR
S01R -> RR
RS01 -> RR
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims - 1):
for b_index in range(f_index + 1, tensor_dims):
if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict):
continue
else:
if f_index in source_spec.dim_partition_dict:
    # build the pair first, then skip (S10, R) -> (R, R)
    f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
    if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
        continue
else:
    f_target_pair = (f_index, [])
if b_index in source_spec.dim_partition_dict:
    # build the pair first, then skip (R, S10) -> (R, R)
    b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
    if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
        continue
else:
    b_target_pair = (b_index, [])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=self.forward_only,
mix_gather=True)
cost_dict = comm_spec.get_comm_cost()
new_dim_partition_dict = {}
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
except ShardingSpecException:
pass
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with one step transform, and
accumulate the communication cost on the original cost, which will finally be used by the auto-sharding solver.
......@@ -344,7 +405,167 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
return valid_spec_dict
def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
"""memory cost of the communication action sequence
Args:
comm_action_sequence (List[CommSpec]): list of communication actions
Returns:
TrainCycleItem: memory (numel) cost of such comm_action_sequence
"""
def compute_shape(sharding_spec: ShardingSpec):
    # local shape: every sharded dim only holds 1/len(shard) of the global dim
    shape = list(sharding_spec.entire_shape)
    for dim, shard in sharding_spec.dim_partition_dict.items():
        shape[dim] = shape[dim] // len(shard)
    return shape
def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze all_gather memory footprint
all_gather will allocate memory for the output tensor, and there will be temp memory for
all_gather operation, which is twice the size of output tensor
Args:
comm_spec (CommSpec): input CommSpec
discard_input (bool): whether to discard the input tensor
alloc_numel (int): current allocated numel
peak_numel (int): current peak numel
"""
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
alloc_numel += output_numel
if discard_input:
alloc_numel -= input_numel
return alloc_numel, peak_numel
def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze split memory footprint
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
generate new tensor in this case, so no memory will be allocated.
Args:
comm_spec (CommSpec): input CommSpec
discard_input (bool): whether to discard the input tensor
alloc_numel (int): current allocated numel
peak_numel (int): current peak numel
"""
shard_dim = comm_spec.shard_dim
if shard_dim != 0:
# if we don't shard the tensor on the first dimension, the split action will
# generate a new tensor
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
alloc_numel += output_numel
peak_numel = max(peak_numel, alloc_numel)
if discard_input:
alloc_numel -= input_numel
else:
# if we shard the tensor on the first dimension, the split action will not generate
# a new tensor, and as it will preserve a reference to the input tensor, we could
# override the discard_input option here
# NOTE: this special case might fail in some weird cases, e.g. if we have three split
# actions in the comm actions sequence, the first split action operate on the second dimension,
# the second split action operate on the first dimension, and the third split action operate, again,
# on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
# memory the same size as the output of first split action. However, the third split action will discard
# the input tensor, and it actually should discard the tensor generated by the first split action, so in
# the current memory estimation framework, we will overestimate the memory usage. But the above case is
# kind of weird, and I think we could ignore it for now.
pass
return alloc_numel, peak_numel
def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""
a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
"""
return alloc_numel, peak_numel
def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze all_to_all memory footprint
all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action
is twice the size of output tensor if we shard input tensor on the first dimension, otherwise
the temp memory is three times the size of output tensor
Args:
comm_spec (CommSpec): input CommSpec
discard_input (bool): whether to discard the input tensor
alloc_numel (int): current allocated numel
peak_numel (int): current peak numel
"""
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel
shard_dim = comm_spec.shard_dim
if shard_dim != 0:
peak_numel = max(peak_numel, alloc_numel + output_numel * 3)
else:
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
alloc_numel += output_numel
if discard_input:
alloc_numel -= input_numel
return alloc_numel, peak_numel
def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""
a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
"""
return alloc_numel, peak_numel
pattern_to_func_dict = {
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis],
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis],
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis],
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis],
# TODO: memory analysis for mix-gather is not implemented yet
CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [],
}
fwd_actions = []
bwd_actions = []
# construct forward and backward comm actions sequence
for comm_spec in comm_action_sequence:
comm_spec: CommSpec
fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
fwd_actions.append(fwd_action)
bwd_actions.append(bwd_action)
# analyze memory footprint of forward comm actions sequence
fwd_alloc_numel = 0
fwd_peak_numel = 0
for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
    fwd_action, comm_spec = action_spec_pair
    # the first forward comm action will not discard its input
    if idx == 0:
        fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
    else:
        fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
# analyze memory footprint for backward comm actions sequence
bwd_alloc_numel = 0
bwd_peak_numel = 0
for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
    bwd_action, comm_spec = action_spec_pair
    # likewise, the first backward comm action will not discard its input
    if idx == 0:
        bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, bwd_peak_numel)
    else:
        bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel)
return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
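# A worked sketch (made-up sizes) of the alloc/peak bookkeeping used above:
# gathering a 1000-element shard over 4 devices allocates a 4000-element
# output and transiently needs about twice the output during communication.
alloc_numel, peak_numel = 0, 0
input_numel, world_size = 1000, 4
output_numel = input_numel * world_size                       # 4000
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)  # 8000 transient
alloc_numel += output_numel                                   # 4000 resident
# with discard_input=True the 1000-element input would be freed afterwards
assert (alloc_numel, peak_numel) == (4000, 8000)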
def shape_consistency(self, source_spec: ShardingSpec,
target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
'''
This method will find a path to transform source_spec to target_spec with
a greedy algorithm.
......@@ -450,7 +671,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
raise RuntimeError(f"Could not find a valid transform path within {MAX_TRANSFORM_STEPS} steps.")
def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor:
'''
Apply target_spec to tensor with source sharding spec, the transform path is generated by the
shape_consistency method.
......