Unverified Commit 5a6fd71f authored by Kirigaya Kazuto's avatar Kirigaya Kazuto Committed by GitHub
Browse files

[pipeline/rpc] update outstanding mechanism | optimize dispatching strategy (#1497)

* support p2p communication with any type of object | pass test

* reconstruct pipeline schedule with p2p_v2.py (supports communication with List[Any]) | pass test

* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule

* [pipeline/rpc] implement a demo for PP with cuda rpc framework

* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatching data in work_list to ensure steady 1F1B

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy
parent 0ed2f461
...@@ -68,7 +68,7 @@ class UniqueKey: ...@@ -68,7 +68,7 @@ class UniqueKey:
class WorkItem: class WorkItem:
__slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id', __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id',
'num_microbatches') 'num_microbatches', 'forward_only')
stage_id: int stage_id: int
phase: Phase phase: Phase
...@@ -81,6 +81,7 @@ class WorkItem: ...@@ -81,6 +81,7 @@ class WorkItem:
batch_id: int batch_id: int
num_microbatches: int num_microbatches: int
forward_only: bool
def __init__(self, def __init__(self,
stage_id, stage_id,
...@@ -91,6 +92,7 @@ class WorkItem: ...@@ -91,6 +92,7 @@ class WorkItem:
microbatch_id, microbatch_id,
batch_id, batch_id,
num_microbatches, num_microbatches,
forward_only,
refcount=0) -> None: refcount=0) -> None:
for attr_name in self.__slots__: for attr_name in self.__slots__:
setattr(self, attr_name, locals()[attr_name]) setattr(self, attr_name, locals()[attr_name])
...@@ -129,36 +131,39 @@ class Worker: ...@@ -129,36 +131,39 @@ class Worker:
pp_rank: int, pp_rank: int,
actual_stage_num: int, actual_stage_num: int,
num_microbatches: int, num_microbatches: int,
max_outstanding: int, use_1F1B: bool,
device: str, device: str,
checkpoint: bool = False) -> None: checkpoint: bool = False) -> None:
super().__init__() super().__init__()
self.pp_rank = pp_rank self.pp_rank = pp_rank
self.actual_stage_num = actual_stage_num self.actual_stage_num = actual_stage_num
self.num_microbatches = num_microbatches self.num_microbatches = num_microbatches
self.max_outstanding = max_outstanding
self.outstanding = 0
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.device = device self.device = device
self.outstanding_range = self._initialize_outstanding_range(pp_rank, actual_stage_num, use_1F1B)
self.future_devices = None if device is None or device == 'cpu' else [device] # variable and const for context managment
self.outstanding = 0
self.forward_times = 0
self.backward_times = 0
self.reset_key = UniqueKey(0, Phase.FORWARD)
# rref of other workers
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
# topology info
self.producer_stage_ids: List[int] = None self.producer_stage_ids: List[int] = None
self.consumer_stage_ids: List[int] = None self.consumer_stage_ids: List[int] = None
# module # module partitions
self.module_partition = module_partition.to(device) self.module_partition = module_partition.to(device)
self.debug_list = [None] * num_microbatches # container to maintain loop
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict() self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
self.work_list: Dict[UniqueKey, WorkItem] = dict() self.work_list: Dict[UniqueKey, WorkItem] = dict()
self.output_list: Dict[UniqueKey, WorkItem] = dict() self.output_list: Dict[UniqueKey, WorkItem] = dict()
# Why must a Lock instead of RLock ? # lock for the list
# Because RLock cannot be pickled
self.work_list_condition_lock = threading.Condition(threading.Lock()) self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock()) self.output_list_condition_lock = threading.Condition(threading.Lock())
...@@ -168,6 +173,15 @@ class Worker: ...@@ -168,6 +173,15 @@ class Worker:
def _get_future_by_device(self): def _get_future_by_device(self):
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
def _initialize_outstanding_range(self, pp_rank: int, actual_stage_num: int, use_1F1B: bool) -> Tuple[int]:
outstanding_range = None
if use_1F1B:
if pp_rank == actual_stage_num - 1:
outstanding_range = (0, 1)
else:
outstanding_range = (actual_stage_num, actual_stage_num)
return outstanding_range
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None: def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs" assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None" assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
...@@ -197,8 +211,15 @@ class Worker: ...@@ -197,8 +211,15 @@ class Worker:
def get_parameter_gradients(self) -> List[torch.Tensor]: def get_parameter_gradients(self) -> List[torch.Tensor]:
return [p.grad for p in self.module_partition.parameters()] return [p.grad for p in self.module_partition.parameters()]
def reset_pp_context(self):
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self.microbatch_id_to_backward_cache.clear()
self.output_list.clear()
# just for first pp_rank # just for first pp_rank
def set_input(self, microbatch_id: int, microbatch: Tuple[Any]): def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
with self.work_list_condition_lock: with self.work_list_condition_lock:
assert self.consumer_stage_ids is not None assert self.consumer_stage_ids is not None
consumer_num = len(self.consumer_stage_ids) consumer_num = len(self.consumer_stage_ids)
...@@ -207,11 +228,10 @@ class Worker: ...@@ -207,11 +228,10 @@ class Worker:
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, consumer_num) self.num_microbatches, forward_only)
self.work_list[key] = work_item self.work_list[key] = work_item
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta') color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all() self.work_list_condition_lock.notify_all()
# just for last pp_rank # just for last pp_rank
...@@ -224,24 +244,22 @@ class Worker: ...@@ -224,24 +244,22 @@ class Worker:
grad_wrt_loss = torch.tensor(1, device=self.device) grad_wrt_loss = torch.tensor(1, device=self.device)
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, producer_num) self.num_microbatches, False)
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta') color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
self.work_list[key] = work_item self.work_list[key] = work_item
self.work_list_condition_lock.notify_all() self.work_list_condition_lock.notify_all()
def subscribe_producer(self, microbatch_id: int): def subscribe_producer(self, microbatch_id: int, forward_only: bool):
""" """
You should call this function asynchronously You should call this function asynchronously
""" """
assert self.producer_stage_ids is not None assert self.producer_stage_ids is not None
producer_num = len(self.producer_stage_ids) producer_num = len(self.producer_stage_ids)
consumer_num = len(self.consumer_stage_ids)
assert producer_num > 0, "only stage that has producers can subscribe producers" assert producer_num > 0, "only stage that has producers can subscribe producers"
stage_id = self.pp_rank stage_id = self.pp_rank
subscribe_forward_futures: List[Future] = [None] * producer_num subscribe_forward_futures: List[Future] = [None] * producer_num
output = self._get_future_by_device() output = self._get_future_by_device()
...@@ -259,9 +277,8 @@ class Worker: ...@@ -259,9 +277,8 @@ class Worker:
producer_args = subscribe_forward_futures[i].wait() producer_args = subscribe_forward_futures[i].wait()
args.extend(producer_args) args.extend(producer_args)
# TODO : not only args
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None, work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, consumer_num) self.num_microbatches, forward_only)
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list # add work_item to work_list
...@@ -279,13 +296,10 @@ class Worker: ...@@ -279,13 +296,10 @@ class Worker:
You should call this function asynchronously You should call this function asynchronously
""" """
assert self.producer_stage_ids is not None assert self.producer_stage_ids is not None
producer_num = len(self.producer_stage_ids)
consumer_num = len(self.consumer_stage_ids) consumer_num = len(self.consumer_stage_ids)
assert consumer_num > 0, "only stage that has consumers can subscribe comsumers" assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
# TODO : is this right?
stage_id = self.pp_rank stage_id = self.pp_rank
subscribe_backward_futures: List[Future] = [None] * consumer_num subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device() output = self._get_future_by_device()
...@@ -305,7 +319,7 @@ class Worker: ...@@ -305,7 +319,7 @@ class Worker:
# flatten args # flatten args
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None, work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, producer_num) self.num_microbatches, False)
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
...@@ -341,32 +355,57 @@ class Worker: ...@@ -341,32 +355,57 @@ class Worker:
while len(self.work_list) == 0: while len(self.work_list) == 0:
self.work_list_condition_lock.wait() self.work_list_condition_lock.wait()
# each stage must do Key(microbatch_id=0, phase=FORWARD) first
# before doing the operation, reset the context first
if self.reset_key in self.work_list:
self.reset_pp_context()
# execute backward first (if backward phase in work_list) # execute backward first (if backward phase in work_list)
select_work_list_key = None pp_rank = self.pp_rank
for key in self.work_list: actual_stage_num = self.actual_stage_num
work_item = self.work_list[key] num_microbatches = self.num_microbatches
if work_item.phase == Phase.FORWARD and \ is_last_stage = pp_rank == actual_stage_num - 1
self.max_outstanding is not None and \ select_work_list_key: UniqueKey = None
self.outstanding >= self.max_outstanding:
continue if self.outstanding_range:
if self.outstanding <= self.outstanding_range[0]:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
elif self.outstanding >= self.outstanding_range[1]:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else: else:
if select_work_list_key is not None and \ raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
select_work_list_key.phase == Phase.FORWARD and \
key.phase == Phase.BACKWARD: target_key = UniqueKey(target_microbatch_id, target_phase)
continue if target_key in self.work_list:
select_work_list_key = target_key
if select_work_list_key is None:
select_work_list_key = key # change outstanding_range at:
else: # 1. forward times reach actual_stage_num, this is the end of continuous forward
phase_pair = (select_work_list_key.phase, key.phase) # 2. forward times reach num_microbatches, this is the end of 1F1B mode
# choose forward first if not is_last_stage and \
if phase_pair == (Phase.BACKWARD, Phase.FORWARD): select_work_list_key is not None and \
select_work_list_key = key select_work_list_key.phase == Phase.FORWARD:
elif phase_pair == (Phase.FORWARD, Phase.BACKWARD): if select_work_list_key.microbatch_id == actual_stage_num - 1:
continue outstanding_min = actual_stage_num - pp_rank - 1
# choose work_item which has a smaller microbactch_id first outstanding_max = actual_stage_num - pp_rank
elif key.microbatch_id < select_work_list_key.microbatch_id: self.outstanding_range = (outstanding_min, outstanding_max)
select_work_list_key = key elif select_work_list_key.microbatch_id == num_microbatches - 1:
self.outstanding_range = (0, 0)
else:
if self.forward_times < num_microbatches:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
target_key = UniqueKey(target_microbatch_id, target_phase)
if target_key in self.work_list:
select_work_list_key = target_key
return select_work_list_key return select_work_list_key
...@@ -375,15 +414,28 @@ class Worker: ...@@ -375,15 +414,28 @@ class Worker:
args = work_item.args args = work_item.args
kwargs = work_item.kwargs kwargs = work_item.kwargs
microbatch_id = work_item.microbatch_id microbatch_id = work_item.microbatch_id
forward_only = work_item.forward_only
consume_result = None consume_result = None
# if self.pp_rank == 0: # TODO : use process manager to acquire rank info later
# print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list)) is_first_stage = (self.pp_rank == 0)
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
# color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue') # if self.pp_rank == 3:
# print(
# f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
# )
if phase == Phase.FORWARD: if phase == Phase.FORWARD:
self.outstanding += 1 # remind its consumer to get data before forward
if not is_last_stage:
for stage_id in self.consumer_stage_ids:
consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only)
self.forward_times += 1
if not forward_only:
self.outstanding += 1
# TODO : more elegant ? # TODO : more elegant ?
for i in range(len(args)): for i in range(len(args)):
...@@ -391,35 +443,46 @@ class Worker: ...@@ -391,35 +443,46 @@ class Worker:
if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad: if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad:
args[i] = arg_obj.requires_grad_() args[i] = arg_obj.requires_grad_()
# TODO : use process manager to acquire rank info later
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
# last stage doesn't need to do checkpoint, for it will do backward instantly # last stage doesn't need to do checkpoint, for it will do backward instantly
if self.checkpoint and not is_last_stage: if forward_only:
with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = None
stage_inputs = None
use_checkpoint = None
elif self.checkpoint and not is_last_stage:
with torch.no_grad(): with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs) consume_result = self.module_partition(*args, **kwargs)
stage_outputs = None stage_outputs = None
stage_inputs = args stage_inputs = args
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, use_checkpoint = True
stage_outputs,
checkpoint=True)
else: else:
consume_result = self.module_partition(*args, **kwargs) consume_result = self.module_partition(*args, **kwargs)
stage_outputs = consume_result stage_outputs = consume_result
stage_inputs = args stage_inputs = args
use_checkpoint = False
if not forward_only:
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
stage_outputs, stage_outputs,
checkpoint=False) checkpoint=use_checkpoint)
consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result
# if it is the last stage, trigger backward automatic # if not forward_only, do the backward
if is_last_stage: if not forward_only:
self._begin_backward(microbatch_id) if is_last_stage: # if it is the last stage, trigger backward automatic
self._begin_backward(microbatch_id)
elif phase == Phase.BACKWARD: elif phase == Phase.BACKWARD:
# remind its producer to get data before backward
if not is_first_stage:
for stage_id in self.producer_stage_ids:
producer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
producer_worker_rref.remote().subscribe_consumer(microbatch_id)
self.backward_times += 1
self.outstanding -= 1 self.outstanding -= 1
assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache" assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache"
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
...@@ -445,6 +508,9 @@ class Worker: ...@@ -445,6 +508,9 @@ class Worker:
return consume_result return consume_result
def _get_store_len(self):
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)}'
# do the main loop to consume ready_list # do the main loop to consume ready_list
def _work_loop(self): def _work_loop(self):
# for init # for init
...@@ -461,7 +527,7 @@ class Worker: ...@@ -461,7 +527,7 @@ class Worker:
work_item = self.work_list.pop(work_item_key) work_item = self.work_list.pop(work_item_key)
color_debug( color_debug(
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}', f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
'work loop', 'green') 'work loop', 'green')
with self.output_list_condition_lock: with self.output_list_condition_lock:
...@@ -472,7 +538,7 @@ class Worker: ...@@ -472,7 +538,7 @@ class Worker:
consume_result = self._consume_work_item_by_phase(work_item) consume_result = self._consume_work_item_by_phase(work_item)
color_debug( color_debug(
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}', f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
'work loop', 'green') 'work loop', 'green')
work_item.output.set_result(consume_result) work_item.output.set_result(consume_result)
...@@ -489,9 +555,6 @@ class Worker: ...@@ -489,9 +555,6 @@ class Worker:
self.optimizer.zero_grad() self.optimizer.zero_grad()
# TODO
# 1. chunk
# 2. checkpoint
class PipelineEngineBase(ABC, nn.Module): class PipelineEngineBase(ABC, nn.Module):
def __init__(self, def __init__(self,
...@@ -499,19 +562,18 @@ class PipelineEngineBase(ABC, nn.Module): ...@@ -499,19 +562,18 @@ class PipelineEngineBase(ABC, nn.Module):
stage_num, stage_num,
num_microbatches, num_microbatches,
device: str, device: str,
max_outstanding=None, use_1F1B=False,
chunk: int = 1, chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None: checkpoint: bool = False) -> None:
super().__init__() super().__init__()
self.module_partitions: List[nn.Module] = module_partitions self.module_partitions: List[nn.Module] = module_partitions
self.chunk = chunk self.chunk = chunk
self.num_microbatches = num_microbatches self.num_microbatches = num_microbatches
self.device = device self.device = device
self.max_outstanding = max_outstanding self.use_1F1B = use_1F1B
self.stage_num = stage_num self.stage_num = stage_num
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.use_interleave = use_interleave self.use_interleave = chunk > 1
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict() self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
...@@ -547,7 +609,7 @@ class PipelineEngineBase(ABC, nn.Module): ...@@ -547,7 +609,7 @@ class PipelineEngineBase(ABC, nn.Module):
def _init_worker(self): def _init_worker(self):
actual_stage_num = self._get_actual_stage_num() actual_stage_num = self._get_actual_stage_num()
max_outstanding = self.max_outstanding use_1F1B = self.use_1F1B
checkpoint = self.checkpoint checkpoint = self.checkpoint
num_microbatches = self.num_microbatches num_microbatches = self.num_microbatches
device = self.device device = self.device
...@@ -560,8 +622,7 @@ class PipelineEngineBase(ABC, nn.Module): ...@@ -560,8 +622,7 @@ class PipelineEngineBase(ABC, nn.Module):
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
Worker, Worker,
args=(module_partition, pp_rank, actual_stage_num, args=(module_partition, pp_rank, actual_stage_num,
num_microbatches, max_outstanding, device, num_microbatches, use_1F1B, device, checkpoint))
checkpoint))
# let each worker know global worker rref (include itself) # let each worker know global worker rref (include itself)
for pp_rank in range(actual_stage_num): for pp_rank in range(actual_stage_num):
...@@ -585,46 +646,55 @@ class PipelineEngineBase(ABC, nn.Module): ...@@ -585,46 +646,55 @@ class PipelineEngineBase(ABC, nn.Module):
grads[stage_id].append(grad) grads[stage_id].append(grad)
return grads return grads
def forward_backward(self, batch: torch.Tensor): def forward_backward(self, batch: torch.Tensor, forward_only: bool = False):
first_stage_worker = self.pp_rank_to_worker_rref[0] num_microbatches = self.num_microbatches
microbatch_size = len(batch) // self.num_microbatches microbatch_size = len(batch) // num_microbatches
actual_stage_num = self._get_actual_stage_num() actual_stage_num = self._get_actual_stage_num()
microbatch_iter = range(self.num_microbatches) first_stage_worker = self.pp_rank_to_worker_rref[0]
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
microbatch_iter = range(num_microbatches)
if use_progress: if use_progress:
microbatch_iter = tqdm(microbatch_iter) microbatch_iter = tqdm(microbatch_iter)
ret_future: List[Future] = [None] * num_microbatches
from time import sleep
for microbatch_id in microbatch_iter: for microbatch_id in microbatch_iter:
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
# forward subscribe asynchronously # control data input speed
for pp_rank in range(1, actual_stage_num, 1): # to prevent exceed of wait limitations
worker_rref = self.pp_rank_to_worker_rref[pp_rank] if microbatch_id >= actual_stage_num:
worker_rref.rpc_async().subscribe_producer(microbatch_id) if forward_only or not self.use_1F1B:
ret_future[microbatch_id - actual_stage_num].wait()
# backward subscribe asynchronously else:
for pp_rank in range(actual_stage_num - 2, -1, -1): key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
worker_rref = self.pp_rank_to_worker_rref[pp_rank] first_stage_worker.rpc_sync().get_output_by_key(key)
worker_rref.rpc_async().subscribe_consumer(microbatch_id)
# run one microbatch # run one microbatch
first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch) first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch, forward_only)
key = UniqueKey(microbatch_id, Phase.FORWARD)
ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
# wait forward # wait forward
# TODO : all the node to output # TODO : all the node to output
forward_result = None forward_result = None
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
for microbatch_id in range(self.num_microbatches): for microbatch_id in range(self.num_microbatches):
key = UniqueKey(microbatch_id, Phase.FORWARD) key = UniqueKey(microbatch_id, Phase.FORWARD)
ret = last_worker_rref.rpc_sync().get_output_by_key(key) ret = ret_future[microbatch_id].wait()
if forward_result is None: if forward_result is None:
forward_result = [[]] * len(ret) forward_result = [[]] * len(ret)
for i in range(len(forward_result)): for i in range(len(forward_result)):
forward_result[i].append(ret[i]) forward_result[i].append(ret[i])
# wait for last backward in rank0 # wait for last backward in rank0
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) if not forward_only:
first_stage_worker.rpc_sync().get_output_by_key(key) key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
first_stage_worker.rpc_sync().get_output_by_key(key)
return forward_result return forward_result
def initialize_optimizer(self, optimizer_class: type, **kwargs): def initialize_optimizer(self, optimizer_class: type, **kwargs):
...@@ -654,11 +724,9 @@ class FillDrainPipelineEngine(PipelineEngineBase): ...@@ -654,11 +724,9 @@ class FillDrainPipelineEngine(PipelineEngineBase):
num_microbatches: int, num_microbatches: int,
device: str, device: str,
chunk: int = 1, chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None: checkpoint: bool = False) -> None:
max_outstanding = None use_1F1B = False
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave, super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
checkpoint)
class OneFOneBPipelineEngine(PipelineEngineBase): class OneFOneBPipelineEngine(PipelineEngineBase):
...@@ -668,11 +736,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase): ...@@ -668,11 +736,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
stage_num: int, stage_num: int,
num_microbatches: int, num_microbatches: int,
device: str, device: str,
max_outstanding=None,
chunk: int = 1, chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None: checkpoint: bool = False) -> None:
if max_outstanding is None: use_1F1B = True
max_outstanding = len(module_partitions) super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
checkpoint)
...@@ -5,13 +5,9 @@ import torch ...@@ -5,13 +5,9 @@ import torch
from torch import nn from torch import nn
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.distributed.rpc as rpc import torch.distributed.rpc as rpc
from torch import autograd
from torch.optim import SGD, Adam, RMSprop, Optimizer from torch.optim import SGD, Adam, RMSprop, Optimizer
from colorama import Back, Style from colorama import Back, Style
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
def color_debug(text, prefix=' ', color='blue'): def color_debug(text, prefix=' ', color='blue'):
color = color.upper() color = color.upper()
...@@ -43,13 +39,13 @@ class RpcTestModel(nn.Module): ...@@ -43,13 +39,13 @@ class RpcTestModel(nn.Module):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--world_size', type=int, default=2) parser.add_argument('--world_size', type=int, default=2)
parser.add_argument('--num_microbatches', type=int, default=2) parser.add_argument('--num_microbatches', type=int, default=2)
parser.add_argument('--chunk', type=int, default=1) parser.add_argument('--chunk', type=int, default=1)
parser.add_argument('--use_checkpoint', action='store_true') parser.add_argument('--use_checkpoint', action='store_true')
parser.add_argument('--use_interleave', action='store_true')
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--master_addr', type=str, default='localhost') parser.add_argument('--master_addr', type=str, default='localhost')
parser.add_argument('--master_port', type=str, default='29020') parser.add_argument('--master_port', type=str, default='29020')
parser.add_argument('--num_worker_threads', type=str, default=128) parser.add_argument('--num_worker_threads', type=str, default=128)
......
import os
import argparse
import torch import torch
from torch import nn from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch import autograd from torch import autograd
from torch.optim import SGD, Adam, RMSprop, Optimizer from torch.optim import SGD, Adam, RMSprop, Optimizer
from colorama import Back, Style
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close from colossalai.testing import assert_close
...@@ -21,7 +15,6 @@ def run_master(args): ...@@ -21,7 +15,6 @@ def run_master(args):
stage_num = args.world_size stage_num = args.world_size
chunk = args.chunk chunk = args.chunk
actual_stage_num = stage_num * chunk actual_stage_num = stage_num * chunk
use_interleave = args.use_interleave
use_checkpoint = args.use_checkpoint use_checkpoint = args.use_checkpoint
num_microbatches = args.num_microbatches num_microbatches = args.num_microbatches
optimizer_class = globals()[args.optimizer] optimizer_class = globals()[args.optimizer]
...@@ -45,7 +38,6 @@ def run_master(args): ...@@ -45,7 +38,6 @@ def run_master(args):
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
device=device, device=device,
chunk=chunk, chunk=chunk,
use_interleave=use_interleave,
checkpoint=use_checkpoint) checkpoint=use_checkpoint)
engine.initialize_optimizer(optimizer_class, lr=lr) engine.initialize_optimizer(optimizer_class, lr=lr)
......
import os
import argparse
import torch import torch
from torch import nn from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from rpc_test_utils import rpc_run, parse_args, RpcTestModel from rpc_test_utils import rpc_run, parse_args, RpcTestModel
...@@ -13,12 +8,12 @@ from rpc_test_utils import rpc_run, parse_args, RpcTestModel ...@@ -13,12 +8,12 @@ from rpc_test_utils import rpc_run, parse_args, RpcTestModel
def run_master(args): def run_master(args):
torch.manual_seed(100) torch.manual_seed(100)
epoch = args.epoch
device = args.device device = args.device
stage_num = args.world_size stage_num = args.world_size
chunk = args.chunk chunk = args.chunk
num_microbatches = args.num_microbatches num_microbatches = args.num_microbatches
actual_stage_num = stage_num * chunk actual_stage_num = stage_num * chunk
use_interleave = args.use_interleave
use_checkpoint = args.use_checkpoint use_checkpoint = args.use_checkpoint
sample_num = 1024 sample_num = 1024
...@@ -38,10 +33,10 @@ def run_master(args): ...@@ -38,10 +33,10 @@ def run_master(args):
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
device=device, device=device,
chunk=chunk, chunk=chunk,
use_interleave=use_interleave,
checkpoint=use_checkpoint) checkpoint=use_checkpoint)
_ = engine.forward_backward(input_sample) for _ in range(epoch):
_ = engine.forward_backward(input_sample, forward_only=False)
if __name__ == "__main__": if __name__ == "__main__":
......
import os
import argparse
import torch import torch
from torch import nn from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch import autograd from torch import autograd
from colorama import Back, Style
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close from colossalai.testing import assert_close
...@@ -20,7 +14,6 @@ def run_master(args): ...@@ -20,7 +14,6 @@ def run_master(args):
stage_num = args.world_size stage_num = args.world_size
chunk = args.chunk chunk = args.chunk
actual_stage_num = stage_num * chunk actual_stage_num = stage_num * chunk
use_interleave = args.use_interleave
use_checkpoint = args.use_checkpoint use_checkpoint = args.use_checkpoint
num_microbatches = args.num_microbatches num_microbatches = args.num_microbatches
...@@ -41,7 +34,6 @@ def run_master(args): ...@@ -41,7 +34,6 @@ def run_master(args):
num_microbatches=num_microbatches, num_microbatches=num_microbatches,
device=device, device=device,
chunk=chunk, chunk=chunk,
use_interleave=use_interleave,
checkpoint=use_checkpoint) checkpoint=use_checkpoint)
forward_result = engine.forward_backward(input_sample) forward_result = engine.forward_backward(input_sample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment