Unverified Commit 4a63034e authored by Mehdi Mirzazadeh, committed by GitHub

Check the number of parameters in the distributed pipeline tests (#728)

parent bcd4748d
@@ -243,7 +243,9 @@ def multi_input_multi_output_layers(devices):
     pipe = DistributedPipeline(graph, chunks=4)
     assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)
-    opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,)
+    parameter_rrefs = pipe.parameter_rrefs()
+    assert len(parameter_rrefs) == 6
+    opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
     losses = []
     for i in range(2):
         with dist_autograd.context() as context_id:
@@ -293,7 +295,9 @@ def auto_graph_extract(devices):
     pipe = DistributedPipeline(graph, chunks=4)
     partitions = extract_partitions(graph, pipe)
     assert [[0, 1], [2], [3], [4]] == partitions, f"partitions={partitions}"
-    opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,)
+    parameter_rrefs = pipe.parameter_rrefs()
+    assert len(parameter_rrefs) == 6
+    opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
     losses = []
     for i in range(2):
         with dist_autograd.context() as context_id:
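Note: the added assert guards that the partitioned pipeline still exposes every model parameter through parameter_rrefs() before the DistributedOptimizer is built, so a partitioning bug that drops parameters fails loudly instead of silently training a subset of the model. A minimal local analogue of the same invariant, assuming a plain two-layer nn.Sequential model (the model and the expected count of 4 are illustrative assumptions, not the fairscale test graph):

import torch
import torch.nn as nn

# A two-layer model has 2 weights + 2 biases = 4 parameter tensors;
# asserting the count catches any step that silently drops parameters
# before they are handed to the optimizer.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
params = list(model.parameters())
assert len(params) == 4
opt = torch.optim.SGD(params, lr=0.05)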