Commit 7f22f90e authored by Woosuk Kwon

Remove xformers

parent afdbe5d3
@@ -2,7 +2,6 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-import xformers.ops as xops
 
 from cacheflow import ops
 from cacheflow.models import InputMetadata
@@ -14,8 +13,20 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = scale
 
-        # Shape-agnostic attention mask.
-        self.attention_mask = xops.LowerTriangularMask()
+    def _masked_attention(
+        self,
+        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
+        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
+        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
+        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
+    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
+        query = query * self.scale
+        attn = torch.einsum('qhd,khd->hqk', query, key)
+        if attn_mask is not None:
+            attn = attn + attn_mask
+        attn = torch.softmax(attn, dim=-1)
+        out = torch.einsum('hqk,khd->qhd', attn, value)
+        return out
 
     def multi_query_kv_attention(
         self,
@@ -24,13 +35,11 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
-        query = query.unsqueeze(0)
-        key = key.unsqueeze(0)
-        value = value.unsqueeze(0)
-        out = xops.memory_efficient_attention(
-            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
-        out = out.squeeze(0)
-        # FIXME(woosuk): Directly write the attention output.
+        # FIXME(woosuk): Replace this with a custom op call.
+        attention_mask = torch.triu(
+            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
+        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
+        out = self._masked_attention(query, key, value, attention_mask)
         output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
@@ -64,15 +73,10 @@ class OPTCacheFlowAttention(nn.Module):
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
             keys = torch.stack(keys, dim=0)
             values = torch.stack(values, dim=0)
 
-            q = q.unsqueeze(0)
-            keys = keys.unsqueeze(0)
-            values = values.unsqueeze(0)
-            out = xops.memory_efficient_attention(
-                q, keys, values, scale=self.scale)
+            out = self._masked_attention(q, keys, values)
             out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)
...
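
For context on what the replacement computes: the new _masked_attention helper is plain-PyTorch scaled-dot-product attention over [num_tokens, num_heads, head_size] tensors, and multi_query_kv_attention now passes it an additive upper-triangular mask instead of relying on xformers' LowerTriangularMask. The following is a standalone sketch, not part of this commit: the masked_attention function, shapes, and seed are illustrative, and the check against torch.nn.functional.scaled_dot_product_attention assumes PyTorch 2.0+.

import torch
import torch.nn.functional as F

def masked_attention(query, key, value, scale, attn_mask=None):
    # query: [num_queries, num_heads, head_size]
    # key, value: [num_keys, num_heads, head_size]
    # attn_mask: additive mask of shape [num_queries, num_keys], or None
    query = query * scale
    attn = torch.einsum('qhd,khd->hqk', query, key)
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    return torch.einsum('hqk,khd->qhd', attn, value)

torch.manual_seed(0)
num_tokens, num_heads, head_size = 7, 4, 32
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)
scale = head_size ** -0.5

# Additive causal mask, built the same way as in multi_query_kv_attention:
# positions j > i get a large negative bias so softmax zeroes them out.
mask = (torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1) * -1e5).to(q.dtype)
out = masked_attention(q, k, v, scale, mask)

# Reference: PyTorch's built-in SDPA with a causal mask; it expects the head
# dimension first, i.e. [num_heads, seq_len, head_size], so transpose in and out.
ref = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True)
ref = ref.transpose(0, 1)
print(torch.allclose(out, ref, atol=1e-4))  # True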
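
single_query_cached_kv_attention gathers one generation token's keys and values out of the block-structured cache before calling the same math without a mask. Below is a rough sketch of that gather under assumed layouts: only the value_cache indexing [block_number, :, block_offset, :] appears in the diff, so the [num_blocks, num_heads, block_size, head_size] layout for key_cache, the block_table, and context_len values are illustrative assumptions.

import torch

num_blocks, num_heads, block_size, head_size = 8, 4, 16, 32
scale = head_size ** -0.5

# Assumed cache layout: [num_blocks, num_heads, block_size, head_size].
key_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)

block_table = [3, 5]   # hypothetical physical blocks owned by this sequence
context_len = 20       # number of cached tokens for this sequence

# Gather the sequence's cached keys/values token by token, mirroring the loop
# in single_query_cached_kv_attention.
keys, values = [], []
for j in range(context_len):
    block_number = block_table[j // block_size]
    block_offset = j % block_size
    keys.append(key_cache[block_number, :, block_offset, :])      # [num_heads, head_size]
    values.append(value_cache[block_number, :, block_offset, :])
keys = torch.stack(keys, dim=0)      # [context_len, num_heads, head_size]
values = torch.stack(values, dim=0)

# Single query token; no mask is needed since it may attend to every cached token.
q = torch.randn(1, num_heads, head_size)
attn = torch.softmax(torch.einsum('qhd,khd->hqk', q * scale, keys), dim=-1)
out = torch.einsum('hqk,khd->qhd', attn, values)
print(out.view(num_heads, head_size).shape)  # torch.Size([4, 32])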