Commit ba84b872 authored by Woosuk Kwon

Fix attention

parent 87e0bcd4
-from typing import Optional, Tuple
+from typing import Optional
 import torch
 import torch.nn as nn
...
@@ -24,8 +24,12 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
+        query = query.unsqueeze(0)
+        key = key.unsqueeze(0)
+        value = value.unsqueeze(0)
         out = xops.memory_efficient_attention(
             query, key, value, attn_bias=self.attention_mask, scale=self.scale)
+        out = out.squeeze(0)
         # FIXME(woosuk): Directly write the attention output.
         output.copy_(out, non_blocking=True)
...
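For context, the fix adapts the module's packed, batch-free tensors to the layout that xformers' memory_efficient_attention expects: a leading batch dimension, i.e. [batch, seq_len, num_heads, head_size]. The sketch below reproduces the same unsqueeze/squeeze pattern in isolation; the tensor sizes and the LowerTriangularMask used as a stand-in for self.attention_mask are illustrative assumptions, not taken from the commit.

import torch
import xformers.ops as xops

# Hypothetical sizes, chosen only for this example.
num_tokens, num_heads, head_size = 16, 12, 64
scale = head_size ** -0.5

# Packed (batch-free) tensors of shape [num_tokens, num_heads, head_size],
# as the attention module receives them. xformers kernels generally need CUDA.
query = torch.randn(num_tokens, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
key = torch.randn_like(query)
value = torch.randn_like(query)

# memory_efficient_attention expects [batch, seq_len, num_heads, head_size].
# unsqueeze(0) supplies a batch of size 1; squeeze(0) drops it afterwards.
out = xops.memory_efficient_attention(
    query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
    attn_bias=xops.LowerTriangularMask(),  # assumed stand-in for self.attention_mask
    scale=scale)
out = out.squeeze(0)
assert out.shape == (num_tokens, num_heads, head_size)

Wrapping and unwrapping a singleton batch dimension this way keeps the caller's packed layout unchanged while satisfying the kernel's input contract.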