"...text-generation-inference.git" did not exist on "a5593ba83ef6d2edd3406497e3ed0573a86e44b6"
Commit 9bf31c6b authored by AntoinePrv's avatar AntoinePrv
Browse files

Replace saved_variables to saved_tensors.

parent db574780
...@@ -18,7 +18,7 @@ class ScatterDiv(Function): ...@@ -18,7 +18,7 @@ class ScatterDiv(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
out, src, index = ctx.saved_variables out, src, index = ctx.saved_tensors
grad_src = None grad_src = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
......
...@@ -19,7 +19,7 @@ class ScatterMax(Function): ...@@ -19,7 +19,7 @@ class ScatterMax(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out, grad_arg): def backward(ctx, grad_out, grad_arg):
index, arg = ctx.saved_variables index, arg = ctx.saved_tensors
grad_src = None grad_src = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
......
...@@ -19,7 +19,7 @@ class ScatterMin(Function): ...@@ -19,7 +19,7 @@ class ScatterMin(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out, grad_arg): def backward(ctx, grad_out, grad_arg):
index, arg = ctx.saved_variables index, arg = ctx.saved_tensors
grad_src = None grad_src = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
......
...@@ -18,7 +18,7 @@ class ScatterMul(Function): ...@@ -18,7 +18,7 @@ class ScatterMul(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
out, src, index = ctx.saved_variables out, src, index = ctx.saved_tensors
grad_src = None grad_src = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment