Unverified commit 4d9b76a8 authored by Sylvain Gugger, committed by GitHub

Fix RWKV backward on GPU (#23774)

parent 8d28dba3
@@ -159,7 +159,7 @@ class RwkvLinearAttention(torch.autograd.Function):
     @staticmethod
     # g stands for grad
-    def backward(ctx, g_output):
+    def backward(ctx, g_output, g_state=None):
         input_dtype = ctx.input_dtype
         time_decay, time_first, key, value, output = ctx.saved_tensors
@@ -188,17 +188,14 @@ class RwkvLinearAttention(torch.autograd.Function):
             g_key,
             g_value,
         )
-        g_time_decay = torch.sum(g_time_decay, dim=0)
-        g_time_first = torch.sum(g_time_first, dim=0)
         return (
-            None,
-            None,
-            None,
             g_time_decay.to(input_dtype),
             g_time_first.to(input_dtype),
             g_key.to(input_dtype),
             g_value.to(input_dtype),
+            None,
+            None,
         )
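For readers unfamiliar with the autograd contract behind this fix: torch.autograd.Function.backward is called with one incoming gradient per forward output, so once forward also returns a state tensor the method needs an extra g_state parameter, and backward must return exactly one value per forward input, in order, with None for non-differentiable arguments. The snippet below is a minimal, hypothetical sketch of that contract; ScaleByTwo is a toy class invented for illustration and is not part of this commit or of the RWKV kernel.

import torch


class ScaleByTwo(torch.autograd.Function):
    """Toy Function illustrating the backward signature/return contract."""

    @staticmethod
    def forward(ctx, x, flag=False):
        # Two inputs (x, flag), two outputs (y, state).
        ctx.save_for_backward(x)  # saved only to mirror the ctx.saved_tensors pattern
        return 2 * x, torch.zeros_like(x)

    @staticmethod
    def backward(ctx, g_output, g_state=None):
        # One incoming gradient per forward output (g_output for y, g_state for state),
        # and one return value per forward input: a tensor gradient for x,
        # None for the non-tensor flag.
        (x,) = ctx.saved_tensors
        return 2 * g_output, None


x = torch.randn(3, requires_grad=True)
y, _ = ScaleByTwo.apply(x, True)
y.sum().backward()
print(x.grad)  # every element is 2.0

This mirrors the shape of the fixed backward in the diff above, which returns its tensor gradients first and None placeholders for the trailing non-tensor inputs.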