Unverified Commit 6878e422 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix] solver bug caused by dict type comm cost (#1686)

parent 3dd69944
...@@ -16,7 +16,6 @@ ELEMENTWISE_FUNC_OP = [ ...@@ -16,7 +16,6 @@ ELEMENTWISE_FUNC_OP = [
torch.multiply, torch.multiply,
torch.nn.functional.relu, torch.nn.functional.relu,
torch.nn.functional.dropout, torch.nn.functional.dropout,
torch.flatten,
# softmax should not be here # softmax should not be here
torch.nn.functional.softmax torch.nn.functional.softmax
] ]
......
...@@ -69,6 +69,7 @@ class ReshapeHandler(OperatorHandler): ...@@ -69,6 +69,7 @@ class ReshapeHandler(OperatorHandler):
shape_consistency_manager = ShapeConsistencyManager() shape_consistency_manager = ShapeConsistencyManager()
_, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec, _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
replicate_input_sharding_spec) replicate_input_sharding_spec)
communication_cost = communication_cost["total"]
# generate resharding cost # generate resharding cost
resharding_costs = self._generate_resharding_costs([input_sharding_spec]) resharding_costs = self._generate_resharding_costs([input_sharding_spec])
......
...@@ -319,6 +319,8 @@ class Solver: ...@@ -319,6 +319,8 @@ class Solver:
obj = 0 obj = 0
for i in range(node_nums): for i in range(node_nums):
assert len(s[i]) == len(c[i]) assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
############################################# #############################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment