Unverified Commit fef5c949 authored by Ziyue Jiang's avatar Ziyue Jiang Committed by GitHub
Browse files

polish pp middleware (#2476)


Co-authored-by: default avatarZiyue Jiang <ziyue.jiang@gmail.com>
parent a5dc4253
......@@ -211,7 +211,7 @@ class WorkerBase(ABC):
refcount = 0
with self.output_list_condition_lock:
if refcount < lifecycle:
if refcount <= lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
......@@ -390,7 +390,7 @@ class WorkerBase(ABC):
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank)
producer_output_key, rank=self.pp_rank, offsets=offsets)
else:
for i in range(producer_num):
......
......@@ -29,9 +29,6 @@ class FillDrainWorker(WorkerBase):
target_key = UniqueKey(target_microbatch_id, target_phase)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key
......
......@@ -120,7 +120,7 @@ def run_master(args):
logger.info(f'{rank=} numel in the partition:{numel}')
# build optim
pp_engine.initialize_optimizer(HybridAdam, lr=1e-3)
pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
ranks_tflops = {}
for n in range(NUM_STEPS):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment