Commit 762fd1c3 authored by Woosuk Kwon

Refactor and annotate types for attention

parent 7f22f90e
-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.nn as nn
@@ -30,24 +30,34 @@ class OPTCacheFlowAttention(nn.Module):

     def multi_query_kv_attention(
         self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
+        query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
+        key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
+        value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
+        prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace this with a custom op call.
-        attention_mask = torch.triu(
-            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
-        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
-        out = self._masked_attention(query, key, value, attention_mask)
-        output.copy_(out, non_blocking=True)
+        # FIXME(woosuk): Replace the following with a custom op.
+        start_idx = 0
+        for prompt_len in prompt_lens:
+            out = output[start_idx:start_idx + prompt_len]
+            q = query[start_idx:start_idx + prompt_len]
+            k = key[start_idx:start_idx + prompt_len]
+            v = value[start_idx:start_idx + prompt_len]
+            attention_mask = torch.triu(
+                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
+            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
+            attention_out = self._masked_attention(q, k, v, attention_mask)
+            out.copy_(attention_out, non_blocking=True)
+            start_idx += prompt_len

     def single_query_cached_kv_attention(
         self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
+        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
+        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
+        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         input_metadata: InputMetadata,
     ) -> None:
         num_heads = value_cache.shape[1]
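
The loop above delegates the per-prompt computation to self._masked_attention, which is not part of this diff. Assuming it is a plain scaled dot-product attention with an additive mask (an assumption, not something this commit shows), a minimal reference sketch for inputs shaped [prompt_len, num_heads, head_size] might look like:

import torch

def masked_attention_sketch(
    q: torch.Tensor,     # [seq_len, num_heads, head_size]
    k: torch.Tensor,     # [seq_len, num_heads, head_size]
    v: torch.Tensor,     # [seq_len, num_heads, head_size]
    mask: torch.Tensor,  # [seq_len, seq_len], additive (-1e5 above the diagonal)
) -> torch.Tensor:
    # Hypothetical stand-in for OPTCacheFlowAttention._masked_attention.
    scale = q.shape[-1] ** -0.5
    # Attention scores per head: [num_heads, seq_len, seq_len].
    scores = torch.einsum('qhd,khd->hqk', q, k) * scale
    # The additive mask broadcasts over the head dimension; the large negative
    # entries above the diagonal suppress attention to future tokens.
    probs = torch.softmax(scores + mask, dim=-1)
    # Back to [seq_len, num_heads, head_size].
    return torch.einsum('hqk,khd->qhd', probs, v)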
@@ -82,15 +92,18 @@ class OPTCacheFlowAttention(nn.Module):
     def forward(
         self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
+        query: torch.Tensor,        # [num_tokens, num_heads * head_size]
+        key: torch.Tensor,          # [num_tokens, num_heads * head_size]
+        value: torch.Tensor,        # [num_tokens, num_heads * head_size]
+        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
-        # Prune out invalid tokens.
+    ) -> torch.Tensor:              # [num_tokens, num_heads * head_size]
+        # Pre-allocate the output tensor.
+        output = torch.empty_like(query)
+
+        # Prune out paddings if any.
         query = query[:input_metadata.num_valid_tokens]
         key = key[:input_metadata.num_valid_tokens]
         value = value[:input_metadata.num_valid_tokens]
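
The key_cache and value_cache annotations describe the paged KV cache layout: keys are stored with the head dimension split into head_size/x chunks of width x, while values are stored contiguously per slot. The snippet below only illustrates that indexing; the concrete values of x and block_size and the write pattern are assumptions for the example, not part of this commit.

import torch

# Illustrative sizes only (assumptions for the example).
num_blocks, num_heads, head_size, block_size, x = 4, 2, 16, 8, 8

key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.zeros(num_blocks, num_heads, block_size, head_size)

# Store one token's key/value at (block, slot) in the cache.
key = torch.randn(num_heads, head_size)
value = torch.randn(num_heads, head_size)
block, slot = 1, 3
key_cache[block, :, :, slot, :] = key.view(num_heads, head_size // x, x)
value_cache[block, :, slot, :] = value

# Reading the slot back recovers the original per-head vectors.
assert torch.equal(
    key_cache[block, :, :, slot, :].reshape(num_heads, head_size), key)
assert torch.equal(value_cache[block, :, slot, :], value)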
@@ -101,18 +114,11 @@ class OPTCacheFlowAttention(nn.Module):
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
+        output = output.view(-1, num_heads, head_size)

         # Compute the attention op for prompts.
-        output = torch.empty_like(query)
-        start_idx = 0
-        for i in range(input_metadata.num_prompts):
-            prompt_len = input_metadata.prompt_lens[i]
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-            self.multi_query_kv_attention(out, q, k, v)
-            start_idx += prompt_len
+        self.multi_query_kv_attention(
+            output, query, key, value, input_metadata.prompt_lens)

         # Wait until the cache op is done.
         if cache_event is not None:
@@ -124,6 +130,7 @@ class OPTCacheFlowAttention(nn.Module):
         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
+            start_idx = sum(input_metadata.prompt_lens)
             self.single_query_cached_kv_attention(
                 output[start_idx:],
                 query[start_idx:],
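
The new start_idx = sum(input_metadata.prompt_lens) relies on the token layout that forward() assumes: all prompt tokens are packed first, followed by one token per generating sequence. A small worked example with assumed sizes:

# Assumed layout of the flattened token dimension (illustrative numbers only).
prompt_lens = [3, 5]          # two prompts in the batch
num_generation_tokens = 2     # two sequences currently decoding

start_idx = sum(prompt_lens)  # 8
# query[0:3]  -> prompt 0, handled by multi_query_kv_attention
# query[3:8]  -> prompt 1, handled by multi_query_kv_attention
# query[8:10] -> generation tokens, handled by single_query_cached_kv_attention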
@@ -132,4 +139,5 @@ class OPTCacheFlowAttention(nn.Module):
                 input_metadata)

         # Reshape the output tensor.
+        # NOTE(woosuk): The output tensor may include paddings.
         return output.view(-1, num_heads * head_size)
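
The added NOTE about paddings follows from the new allocation order: output is created with torch.empty_like(query) before the inputs are pruned to num_valid_tokens, so its trailing rows are never written. A sketch of that effect, with made-up sizes:

import torch

num_tokens, num_valid_tokens = 10, 8          # last 2 tokens are padding
num_heads, head_size = 2, 4

query = torch.randn(num_tokens, num_heads * head_size)
output = torch.empty_like(query)              # allocated before pruning: 10 rows
query = query[:num_valid_tokens]              # the attention ops only see 8 rows

# The attention ops fill output[:num_valid_tokens]; rows 8 and 9 keep whatever
# uninitialized data empty_like() produced, which is why callers must ignore
# the padded positions of the returned [num_tokens, num_heads * head_size] tensor.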