"...text-generation-inference.git" did not exist on "27ff1871b507e4f163d7fc6991915f6bb7057f92"
Unverified Commit a9c83bce authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix #3437 (#3440)

parent 579cd3eb
......@@ -102,7 +102,7 @@ def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argX in SpMM forward stage."""
if req_grad_X:
if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']:
return True
return False
......@@ -110,7 +110,7 @@ def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argY in SpMM forward stage."""
if req_grad_Y:
if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']:
return True
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment