Commit d4bc1a4d authored by Woosuk Kwon

Add unoptimized OPT Attention

parent b56b6ca0
from typing import Optional, Tuple

import torch
import torch.nn as nn
import xformers.ops as xops

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

        # Shape-agnostic attention mask.
        self.attention_mask = xops.LowerTriangularMask()

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
        out = xops.memory_efficient_attention(
            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
        # FIXME(woosuk): Directly write the attention output.
        output.copy_(out, non_blocking=True)
    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        # value_cache shape: [num_blocks, num_heads, block_size, head_size].
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[2]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i]                                        # [num_heads, head_size]
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

            # Gather the cached keys for this sequence via the block table.
            keys = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                k = key_cache[block_number, :, :, block_offset, :]
                k = k.view(num_heads, head_size)
                keys.append(k)
            keys = torch.stack(keys, dim=-1)                    # [num_heads, head_size, context_len]

            # Scaled dot-product scores of the single query against the cached keys.
            logits = (q.unsqueeze(1) @ keys).squeeze(1) * self.scale
            attention_weights = torch.softmax(logits, dim=-1)   # [num_heads, context_len]

            # Gather the cached values and take the attention-weighted sum.
            values = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                v = value_cache[block_number, :, block_offset, :]
                values.append(v)
            values = torch.stack(values, dim=-1)                # [num_heads, head_size, context_len]
            out = (values @ attention_weights.unsqueeze(-1)).squeeze(-1)

            output[i].copy_(out, non_blocking=True)
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        # The input tokens are laid out with all prompt tokens first,
        # followed by the generation tokens.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)
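
The cached-KV path above locates each token's key and value through a block table: position j of a sequence lives in logical block j // block_size at offset j % block_size, and the block table maps that logical block to a physical block in the cache. Below is a minimal standalone sketch of that lookup; the helper name and example sizes are illustrative and not part of this commit, and the key-cache head_size split is an assumption inferred from the indexing above.

import torch

def lookup_cached_kv(key_cache, value_cache, block_table, position, block_size):
    # Illustrative only: mirrors the indexing in single_query_cached_kv_attention.
    logical_block = position // block_size           # which block of the sequence
    block_offset = position % block_size             # slot within that block
    block_number = int(block_table[logical_block])   # physical block in the cache
    # Assumed shapes: key_cache   [num_blocks, num_heads, head_size // x, block_size, x]
    #                 value_cache [num_blocks, num_heads, block_size, head_size]
    k = key_cache[block_number, :, :, block_offset, :].reshape(key_cache.shape[1], -1)
    v = value_cache[block_number, :, block_offset, :]
    return k, v

# Example: 7 physical blocks of 8 slots, 12 heads of size 64, assumed split x = 8.
key_cache = torch.randn(7, 12, 64 // 8, 8, 8)
value_cache = torch.randn(7, 12, 8, 64)
block_table = torch.tensor([3, 0, 5])                # logical block -> physical block
k, v = lookup_cached_kv(key_cache, value_cache, block_table, position=10, block_size=8)
assert k.shape == (12, 64) and v.shape == (12, 64)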
"""1D OPT model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import OPTConfig
from transformers import PreTrainedModel
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -31,17 +39,27 @@ class OPTAttention(nn.Module):
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        q = self.q_proj(hidden_states) * self.scaling
+        self.attn = OPTCacheFlowAttention(scale=self.scaling)

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
-        # TODO
-        attn_output = None
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(
+            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output = self.out_proj(attn_output)
        return output
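
The TODO in the hunk above suggests fusing the three projections into a single QKV linear layer. A minimal sketch of what that fusion could look like follows; the class name is hypothetical and the code is not part of this commit.

import torch
import torch.nn as nn

class FusedQKVProjection(nn.Module):
    # Illustrative sketch: one GEMM produces Q, K, and V, then the output is split.
    def __init__(self, embed_dim: int, bias: bool = True) -> None:
        super().__init__()
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)

    def forward(self, hidden_states: torch.Tensor):
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)   # split back into the three projections
        return q, k, v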
@@ -66,13 +84,23 @@ class OPTDecoderLayer(nn.Module):
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states

        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event)
        hidden_states = residual + hidden_states

        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
@@ -145,6 +173,9 @@ class OPTDecoder(OPTPreTrainedModel):
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
@@ -153,8 +184,14 @@ class OPTDecoder(OPTPreTrainedModel):
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

-        for layer in self.layers:
-            hidden_states = layer(hidden_states)
+        for i in range(len(self.layers)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states, kv_caches[i], input_metadata, cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
@@ -175,8 +212,12 @@ class OPTModel(OPTPreTrainedModel):
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
-        return self.decoder(input_ids, positions)
+        return self.decoder(
+            input_ids, positions, kv_caches, input_metadata, cache_events)


class OPTForCausalLM(OPTPreTrainedModel):
@@ -185,9 +226,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)
-        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+        self.sampler = Sampler(embedding=self.lm_head.weight)

        # Initialize weights and apply final processing
        self.post_init()
@@ -196,7 +237,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
-    ) -> torch.Tensor:
-        hidden_states = self.model.decoder(input_ids, positions)
-        logits = self.lm_head(hidden_states).contiguous()
-        return logits
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, Tuple[int, int]]:
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
+        next_tokens = self.sampler(hidden_states, input_metadata)
+        return next_tokens
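
For context, here is a rough sketch of how a caller might allocate the per-layer kv_caches that the new forward signature expects. All sizes below are illustrative assumptions; only the tensor ranks and the (key_cache, value_cache) pairing follow from the attention code in this commit, and InputMetadata construction is omitted since its fields are defined elsewhere in cacheflow.

import torch

num_layers, num_blocks, block_size = 12, 128, 8      # assumed sizes
num_heads, head_size, x = 12, 64, 8                   # x: assumed head_size split for keys

kv_caches = []
for _ in range(num_layers):
    key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x)
    value_cache = torch.empty(num_blocks, num_heads, block_size, head_size)
    kv_caches.append((key_cache, value_cache))

# next_tokens = model(input_ids, positions, kv_caches, input_metadata, cache_events)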