Unverified Commit 6878e422 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix] solver bug caused by dict type comm cost (#1686)

parent 3dd69944
......@@ -16,7 +16,6 @@ ELEMENTWISE_FUNC_OP = [
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
# softmax should not be here
torch.nn.functional.softmax
]
......
......@@ -69,6 +69,7 @@ class ReshapeHandler(OperatorHandler):
shape_consistency_manager = ShapeConsistencyManager()
_, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
replicate_input_sharding_spec)
communication_cost = communication_cost["total"]
# generate resharding cost
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
......
......@@ -319,6 +319,8 @@ class Solver:
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment