Unverified Commit 681606f0 authored by Mehdi Mirzazadeh, committed by GitHub

fixing bug in setting dependencies in partition handler (#723)

* fixing bug in setting dependencies in partition handler

* modifying unit test so that it requires the fix

* black
parent bc1e60e0
@@ -108,15 +108,13 @@ class DistributedPipelineRecord:
         # with this constraint, replace the condition 'self.rank > 0' below with
         # a more accurate one.
         if chunk != 0 and self.consumers and self.rank > 0:
-            dependant_tensors = []
             batch = self.batches[chunk]
             assert batch is not None
-            for tensor, remote_ph_list in zip(batch.tensors, self.forwarded_phony[chunk - 1]):
-                dependant = tensor
+            dependant_tensors = list(batch.tensors)
+            for remote_ph_list in self.forwarded_phony[chunk - 1]:
                 for remote_ph in remote_ph_list:
                     phony = remote_ph.to_here()
-                    dependant = join(dependant, phony)
-                dependant_tensors.append(dependant)
+                    dependant_tensors[0] = join(dependant_tensors[0], phony)
             self.batches[chunk] = Batch(tuple(dependant_tensors), chunk)

     def sync_stream(self, chunk: int, stream: torch.cuda.Stream) -> None:
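Side note, sketched for illustration (not part of the commit): the old loop zipped batch.tensors against the phony lists in forwarded_phony[chunk - 1], which presumably misbehaved when the two sequences differ in length, since Python's zip() truncates to the shorter input and any extra phonies are never joined; the fixed loop walks every phony list and joins all phonies into the first tensor. A minimal standalone illustration with made-up values:

    # zip() silently stops at the shorter sequence.
    batch_tensors = ["t0"]              # chunk carrying a single tensor
    phony_lists = [["ph0"], ["ph1"]]    # phonies forwarded from two sources

    paired = list(zip(batch_tensors, phony_lists))
    assert paired == [("t0", ["ph0"])]  # ["ph1"] is dropped -> lost dependency

    # Fixed pattern: visit every phony list so nothing is skipped.
    joined = [ph for phs in phony_lists for ph in phs]
    assert joined == ["ph0", "ph1"]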
@@ -289,14 +289,16 @@ def auto_graph_extract(devices):
     # create model
     model = nn.Sequential(
-        RemoteModule(devices[0], nn.Linear, (4, 4), {}), ShardedLinearLayer(devices[0], devices, devices[1])
+        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
+        ShardedLinearLayer(devices[0], devices, devices[1]),
+        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
     )
     graph = make_graph(model)
     pipe = DistributedPipeline(graph, chunks=4)
     partitions = extract_partitions(graph, pipe)
-    assert [[0, 1], [2], [3], [4]] == partitions, f"partitions={partitions}"
+    assert [[0, 1], [2], [3], [4], [5]] == partitions, f"partitions={partitions}"
     parameter_rrefs = pipe.parameter_rrefs()
-    assert len(parameter_rrefs) == 6
+    assert len(parameter_rrefs) == 8
     opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
     losses = []
     for i in range(2):
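Side note, sketched for illustration (not part of the commit): the updated assertions follow from the model change above. Each nn.Linear(4, 4) holds two parameter tensors (a 4x4 weight and a 4-element bias), so the extra RemoteModule adds 2 parameter RRefs (6 -> 8) and one more single-node partition ([4] gains a trailing [5]). A quick local check:

    import torch.nn as nn

    linear = nn.Linear(4, 4)
    # weight (4x4) + bias (4,) -> two parameter tensors per Linear module
    assert len(list(linear.parameters())) == 2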