"...text-generation-inference.git" did not exist on "aac64ddaea91f6d342566c5a47cfb53c487eb769"
Unverified Commit 4a63034e authored by Mehdi Mirzazadeh's avatar Mehdi Mirzazadeh Committed by GitHub
Browse files

checking number parameters in distributed pipeline test (#728)

parent bcd4748d
...@@ -243,7 +243,9 @@ def multi_input_multi_output_layers(devices): ...@@ -243,7 +243,9 @@ def multi_input_multi_output_layers(devices):
pipe = DistributedPipeline(graph, chunks=4) pipe = DistributedPipeline(graph, chunks=4)
assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe) assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)
opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,) parameter_rrefs = pipe.parameter_rrefs()
assert len(parameter_rrefs) == 6
opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
losses = [] losses = []
for i in range(2): for i in range(2):
with dist_autograd.context() as context_id: with dist_autograd.context() as context_id:
...@@ -293,7 +295,9 @@ def auto_graph_extract(devices): ...@@ -293,7 +295,9 @@ def auto_graph_extract(devices):
pipe = DistributedPipeline(graph, chunks=4) pipe = DistributedPipeline(graph, chunks=4)
partitions = extract_partitions(graph, pipe) partitions = extract_partitions(graph, pipe)
assert [[0, 1], [2], [3], [4]] == partitions, f"partitions={partitions}" assert [[0, 1], [2], [3], [4]] == partitions, f"partitions={partitions}"
opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05,) parameter_rrefs = pipe.parameter_rrefs()
assert len(parameter_rrefs) == 6
opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05,)
losses = [] losses = []
for i in range(2): for i in range(2):
with dist_autograd.context() as context_id: with dist_autograd.context() as context_id:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment