"tests/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "80364c7686873a8cf28099d540d706d68e8fa194"
Unverified Commit 31fffd3f authored by Boyuan Yao's avatar Boyuan Yao Committed by GitHub
Browse files

[fx] fix wrong variable name in solver rotor (#1502)

* [fx] fix wrong variable name in solver rotor

* [fx] fix wrong variable name in solver rotor

* code modification
parent 3b6a5e25
...@@ -73,7 +73,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple: ...@@ -73,7 +73,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
return (opt, what) return (opt, what)
def _rec(chain, lmin, lmax, cmem, opt_table): def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
""" chain : the class describing the AC graph """ chain : the class describing the AC graph
lmin : index of the first forward to execute lmin : index of the first forward to execute
lmax : upper bound index of the last forward to execute (not included) lmax : upper bound index of the last forward to execute (not included)
...@@ -97,14 +97,14 @@ def _rec(chain, lmin, lmax, cmem, opt_table): ...@@ -97,14 +97,14 @@ def _rec(chain, lmin, lmax, cmem, opt_table):
if what[cmem][lmin][lmax][0]: if what[cmem][lmin][lmax][0]:
sequence.insert(ForwardEnable(lmin)) sequence.insert(ForwardEnable(lmin))
sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweigth[lmin + 1], opt_table)) sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
sequence.insert(Backward(lmin)) sequence.insert(Backward(lmin))
else: else:
j = what[cmem][lmin][lmax][1] j = what[cmem][lmin][lmax][1]
sequence.insert(ForwardCheck(lmin)) sequence.insert(ForwardCheck(lmin))
for k in range(lmin + 1, j): for k in range(lmin + 1, j):
sequence.insert(ForwardNograd(k)) sequence.insert(ForwardNograd(k))
sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweigth[j], opt_table)) sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table)) sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
return sequence return sequence
......
...@@ -44,9 +44,9 @@ class Forward(Operation): ...@@ -44,9 +44,9 @@ class Forward(Operation):
def __repr__(self): def __repr__(self):
return "{n}_{i}".format(n=self.name, i=self.index) return "{n}_{i}".format(n=self.name, i=self.index)
def cost(self, chain): def cost(self, chain: Chain):
if chain is not None: if chain is not None:
return chain.fweigth[self.index] return chain.fweight[self.index]
else: else:
return 1 return 1
...@@ -80,9 +80,9 @@ class Forwards(Operation): ...@@ -80,9 +80,9 @@ class Forwards(Operation):
def __repr__(self): def __repr__(self):
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
def cost(self, chain): def cost(self, chain: Chain):
if chain is not None: if chain is not None:
return sum(chain.fweigth[self.index[0]:self.index[1] + 1]) return sum(chain.fweight[self.index[0]:self.index[1] + 1])
else: else:
return (self.index[1] - self.index[0] + 1) return (self.index[1] - self.index[0] + 1)
...@@ -99,9 +99,9 @@ class Backward(Operation): ...@@ -99,9 +99,9 @@ class Backward(Operation):
def __repr__(self): def __repr__(self):
return "B_{i}".format(i=self.index) return "B_{i}".format(i=self.index)
def cost(self, chain): def cost(self, chain: Chain):
if chain is not None: if chain is not None:
return chain.bweigth[self.index] return chain.bweight[self.index]
else: else:
return 1 return 1
...@@ -126,7 +126,7 @@ class MemoryAccess(Operation): ...@@ -126,7 +126,7 @@ class MemoryAccess(Operation):
def __repr__(self): def __repr__(self):
return "{n}_{i}".format(n=self.name, i=self.index) return "{n}_{i}".format(n=self.name, i=self.index)
def cost(self, chain): def cost(self, chain: Chain):
return 0 return 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment