Unverified Commit a9c83bce authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix #3437 (#3440)

parent 579cd3eb
...@@ -102,7 +102,7 @@ def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y): ...@@ -102,7 +102,7 @@ def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y): def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argX in SpMM forward stage.""" """Rules to identify whether to cache argX in SpMM forward stage."""
if req_grad_X: if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']: if reduce_op in ['min', 'max']:
return True return True
return False return False
...@@ -110,7 +110,7 @@ def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y): ...@@ -110,7 +110,7 @@ def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y): def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argY in SpMM forward stage.""" """Rules to identify whether to cache argY in SpMM forward stage."""
if req_grad_Y: if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']: if reduce_op in ['min', 'max']:
return True return True
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment