Unverified Commit 8bf97719 authored by Lingfan Yu's avatar Lingfan Yu Committed by GitHub
Browse files

[Bugfix] Fix performance bug in edge softmax operator (#624)

parent 7ad0485a
...@@ -126,7 +126,8 @@ class EdgeSoftmax1(th.autograd.Function): ...@@ -126,7 +126,8 @@ class EdgeSoftmax1(th.autograd.Function):
g.edata[grad_score_name] = out * grad_out g.edata[grad_score_name] = out * grad_out
g.update_all(fn.copy_e(grad_score_name, 'm'), fn.sum('m', accum_name)) g.update_all(fn.copy_e(grad_score_name, 'm'), fn.sum('m', accum_name))
g.apply_edges(fn.e_mul_v(out_name, accum_name, out_name)) g.apply_edges(fn.e_mul_v(out_name, accum_name, out_name))
grad_score = g.edata[grad_score_name] - g.edata[out_name] g.ndata.pop(accum_name)
grad_score = g.edata.pop(grad_score_name) - g.edata.pop(out_name)
return None, grad_score return None, grad_score
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment