Commit 762fd1c3 authored by Woosuk Kwon

Refactor and annotate types for attention

parent 7f22f90e
-from typing import Optional
+from typing import List, Optional
 import torch
 import torch.nn as nn

@@ -30,24 +30,34 @@ class OPTCacheFlowAttention(nn.Module):
     def multi_query_kv_attention(
         self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        output: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
+        query: torch.Tensor,   # [num_prompt_tokens, num_heads, head_size]
+        key: torch.Tensor,     # [num_prompt_tokens, num_heads, head_size]
+        value: torch.Tensor,   # [num_prompt_tokens, num_heads, head_size]
+        prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace this with a custom op call.
-        attention_mask = torch.triu(
-            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
-        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
-        out = self._masked_attention(query, key, value, attention_mask)
-        output.copy_(out, non_blocking=True)
+        # FIXME(woosuk): Replace the following with a custom op.
+        start_idx = 0
+        for prompt_len in prompt_lens:
+            out = output[start_idx:start_idx + prompt_len]
+            q = query[start_idx:start_idx + prompt_len]
+            k = key[start_idx:start_idx + prompt_len]
+            v = value[start_idx:start_idx + prompt_len]
+            attention_mask = torch.triu(
+                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
+            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
+            attention_out = self._masked_attention(q, k, v, attention_mask)
+            out.copy_(attention_out, non_blocking=True)
+            start_idx += prompt_len
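For orientation, the refactored loop applies ordinary causal self-attention to each packed prompt in turn. The sketch below mirrors that logic with a plain softmax(QK^T / sqrt(d) + mask)V helper standing in for the module's `_masked_attention`; the helper, shapes, and sizes are assumptions for illustration, not code from this commit.

```python
import math
from typing import List

import torch


def _masked_attention_sketch(q, k, v, mask):
    # q, k, v: [seq_len, num_heads, head_size]; mask: [seq_len, seq_len].
    scores = torch.einsum('qhd,khd->hqk', q, k) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores + mask, dim=-1)  # mask broadcasts over heads
    return torch.einsum('hqk,khd->qhd', probs, v)


def multi_query_kv_attention_sketch(output, query, key, value,
                                    prompt_lens: List[int]) -> None:
    start_idx = 0
    for prompt_len in prompt_lens:
        q = query[start_idx:start_idx + prompt_len]
        k = key[start_idx:start_idx + prompt_len]
        v = value[start_idx:start_idx + prompt_len]
        # Upper-triangular bias of -1e5 hides future positions (causal mask).
        mask = torch.triu(torch.ones(prompt_len, prompt_len), diagonal=1) * -1e5
        mask = mask.to(dtype=q.dtype, device=q.device)
        out = _masked_attention_sketch(q, k, v, mask)
        output[start_idx:start_idx + prompt_len].copy_(out)
        start_idx += prompt_len


# Two prompts of lengths 3 and 2, packed back to back along the token dim.
q = torch.randn(5, 2, 4)
out = torch.empty_like(q)
multi_query_kv_attention_sketch(out, q, q, q, prompt_lens=[3, 2])
```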
     def single_query_cached_kv_attention(
         self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
+        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
+        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
+        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         input_metadata: InputMetadata,
     ) -> None:
         num_heads = value_cache.shape[1]
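The new shape comments document the paged KV-cache layout: the key cache splits head_size into head_size/x chunks of x contiguous elements (presumably so the kernel can read keys in fixed-width vector units), while the value cache keeps a flat head_size per slot. A rough sketch of how one token's key and value map into that layout; all sizes, including x, are invented here.

```python
import torch

# Hypothetical sizes for illustration only.
num_blocks, num_heads, head_size, block_size, x = 4, 2, 64, 8, 8

key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.zeros(num_blocks, num_heads, block_size, head_size)

# Store one token's key/value into block b, slot s within that block.
b, s = 1, 3
key = torch.randn(num_heads, head_size)
value = torch.randn(num_heads, head_size)
key_cache[b, :, :, s, :] = key.view(num_heads, head_size // x, x)
value_cache[b, :, s, :] = value

# Reading the key back re-assembles the flat head_size dimension.
assert torch.equal(key_cache[b, :, :, s, :].reshape(num_heads, head_size), key)
```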
@@ -82,15 +92,18 @@ class OPTCacheFlowAttention(nn.Module):
     def forward(
         self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
+        query: torch.Tensor,        # [num_tokens, num_heads * head_size]
+        key: torch.Tensor,          # [num_tokens, num_heads * head_size]
+        value: torch.Tensor,        # [num_tokens, num_heads * head_size]
+        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
-        # Prune out invalid tokens.
+    ) -> torch.Tensor:              # [num_tokens, num_heads * head_size]
+        # Pre-allocate the output tensor.
+        output = torch.empty_like(query)
+
+        # Prune out paddings if any.
         query = query[:input_metadata.num_valid_tokens]
         key = key[:input_metadata.num_valid_tokens]
         value = value[:input_metadata.num_valid_tokens]
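Allocating `output` before the pruning step is what lets `forward` return a tensor with the full, possibly padded `num_tokens` rows while the attention math runs only over the valid tokens. A toy sketch of that ordering, with made-up sizes:

```python
import torch

num_tokens, num_valid_tokens, hidden = 10, 7, 16   # hypothetical sizes
query = torch.randn(num_tokens, hidden)

output = torch.empty_like(query)   # allocated first: keeps all num_tokens rows
query = query[:num_valid_tokens]   # pruned view used for the attention ops

assert output.shape[0] == num_tokens        # padded rows still exist (uninitialized)
assert query.shape[0] == num_valid_tokens
```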
@@ -101,18 +114,11 @@ class OPTCacheFlowAttention(nn.Module):
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
+        output = output.view(-1, num_heads, head_size)

         # Compute the attention op for prompts.
-        output = torch.empty_like(query)
-        start_idx = 0
-        for i in range(input_metadata.num_prompts):
-            prompt_len = input_metadata.prompt_lens[i]
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-            self.multi_query_kv_attention(out, q, k, v)
-            start_idx += prompt_len
+        self.multi_query_kv_attention(
+            output, query, key, value, input_metadata.prompt_lens)

         # Wait until the cache op is done.
         if cache_event is not None:
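Reshaping `output` alongside `query`/`key`/`value` lets both attention paths write per-head results directly into the pre-allocated buffer, and the final view at the end of `forward` merges the heads back. Because these are views over the same storage, the split/merge round trip costs nothing, as this small check (invented sizes) illustrates:

```python
import torch

num_tokens, num_heads, head_size = 5, 4, 8      # hypothetical sizes
flat = torch.randn(num_tokens, num_heads * head_size)

split = flat.view(-1, num_heads, head_size)     # per-head layout used by the ops
merged = split.view(-1, num_heads * head_size)  # layout returned by forward()

assert merged.data_ptr() == flat.data_ptr()     # same storage, no copy
assert torch.equal(merged, flat)
```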
@@ -124,6 +130,7 @@ class OPTCacheFlowAttention(nn.Module):
         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
+            start_idx = sum(input_metadata.prompt_lens)
             self.single_query_cached_kv_attention(
                 output[start_idx:],
                 query[start_idx:],
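Since prompt tokens are packed first along the token dimension, `sum(input_metadata.prompt_lens)` is exactly the offset where generation tokens begin, so the previously loop-carried `start_idx` can now be computed locally. A minimal sketch of that packing assumption (lengths invented):

```python
prompt_lens = [3, 5, 2]            # hypothetical batch of prompts
segments = []
start_idx = 0
for prompt_len in prompt_lens:
    segments.append((start_idx, start_idx + prompt_len))
    start_idx += prompt_len

assert segments == [(0, 3), (3, 8), (8, 10)]
assert start_idx == sum(prompt_lens)   # generation tokens start right here
```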
@@ -132,4 +139,5 @@ class OPTCacheFlowAttention(nn.Module):
                 input_metadata)

         # Reshape the output tensor.
+        # NOTE(woosuk): The output tensor may include paddings.
         return output.view(-1, num_heads * head_size)