Commit ba84b872 authored by Woosuk Kwon

Fix attention

parent 87e0bcd4
-from typing import Optional, Tuple
+from typing import Optional
 import torch
 import torch.nn as nn
@@ -24,8 +24,12 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
+        query = query.unsqueeze(0)
+        key = key.unsqueeze(0)
+        value = value.unsqueeze(0)
         out = xops.memory_efficient_attention(
             query, key, value, attn_bias=self.attention_mask, scale=self.scale)
+        out = out.squeeze(0)
         # FIXME(woosuk): Directly write the attention output.
         output.copy_(out, non_blocking=True)
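For context on what the added lines fix: xformers' xops.memory_efficient_attention expects batched inputs of shape (batch, seq_len, num_heads, head_size), while this layer holds unbatched (seq_len, num_heads, head_size) tensors, so the commit adds a leading batch dimension before the call and drops it afterward. Below is a minimal standalone sketch of that pattern; the tensor sizes and the LowerTriangularMask bias are illustrative assumptions, not values taken from this commit (the commit uses self.attention_mask).

# Sketch only: sizes and the causal bias are assumptions for illustration.
import torch
import xformers.ops as xops

seq_len, num_heads, head_size = 16, 12, 64
scale = head_size ** -0.5

# Unbatched per-token projections, as the attention layer receives them.
query = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.half)
key = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.half)
value = torch.randn(seq_len, num_heads, head_size, device="cuda", dtype=torch.half)
output = torch.empty_like(query)

# Add the batch dimension xformers requires, run attention, then drop it.
out = xops.memory_efficient_attention(
    query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
    attn_bias=xops.LowerTriangularMask(), scale=scale)
output.copy_(out.squeeze(0), non_blocking=True)

The trailing copy_ into a preallocated buffer is what the FIXME refers to: ideally the kernel would write into output directly instead of materializing out first.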