Commit 45dc04f3 authored by thomwolf's avatar thomwolf
Browse files

tf model [WIP]

parent 24831477
......@@ -111,7 +111,7 @@ class MultiHeadAttention(torch.nn.Module):
v = self.split_into_heads(v, batch_size)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
k = torch.cat((past_key, k), dim=-1)
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
present = torch.stack((k, v))
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment