Commit ab7d68e7 authored by Casper Hansen

Correct input sizes

parent a4626828
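In short, the fused rotary embedding now receives token-flattened query/key tensors sized by the number of query heads and key/value heads respectively, and QuantLlamaAttention does the flattening and un-flattening itself instead of QuantLlamaRotaryEmbedding.forward. A rough sketch of that shape bookkeeping, using made-up sizes and a no-op stand-in for the rotary call (not code from this repository):

import torch

# Made-up sizes for illustration; real values come from the model config.
bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 8, 32, 32, 128

query = torch.randn(bsz, q_len, num_heads * head_dim)
key = torch.randn(bsz, q_len, num_kv_heads * head_dim)
position_ids = torch.arange(q_len).unsqueeze(0).expand(bsz, -1)

# [num_tokens, num_heads * head_size] and [num_tokens, num_kv_heads * head_size]
flat_query = query.view(bsz * q_len, num_heads * head_dim)
flat_key = key.view(bsz * q_len, num_kv_heads * head_dim)

# [num_tokens] -- one position id per flattened token
positions = position_ids.reshape(-1)

def rotary_emb(q, k, pos):
    # stand-in for QuantLlamaRotaryEmbedding.forward, which now gets
    # already-flattened inputs and returns tensors of the same shape
    return q, k

flat_query, flat_key = rotary_emb(flat_query, flat_key, positions)

# restore the [bsz, heads, seq, head_dim] layout used by the attention op
query = flat_query.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key = flat_key.view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)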
@@ -99,6 +99,7 @@ class LlamaFuser:
         attn = QuantLlamaAttention(
             module.hidden_size,
             module.num_heads,
+            module.num_key_value_heads,
             qkv_layer,
             module.o_proj,
             qkv_layer.qweight.device,
@@ -30,6 +30,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)

+        # [max_position, rot_dim]
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

     def forward(
@@ -38,11 +39,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         key: torch.Tensor,
         positions: torch.Tensor,
     ):
-        batch_size, seq_len, _ = query.shape
-        query = query.view(batch_size * seq_len, -1)
-        key = key.view(batch_size * seq_len, -1)
-        positions = positions.view(-1).to(query.device)
-
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         query = query.contiguous()
@@ -57,9 +53,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
             True # is_neox
         )

-        query = query.view(batch_size, seq_len, -1)
-        key = key.view(batch_size, seq_len, -1)
-
         return query, key

 class QuantLlamaAttention(nn.Module):
@@ -69,15 +62,17 @@ class QuantLlamaAttention(nn.Module):
         self,
         hidden_size,
         num_heads,
+        num_kv_heads,
         qkv_proj,
         o_proj,
         dev,
         max_new_tokens,
-        use_hf_rotary=True
+        use_hf_rotary=False
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
         self.head_dim = hidden_size // num_heads
         self.use_hf_rotary = use_hf_rotary
@@ -100,32 +95,44 @@ class QuantLlamaAttention(nn.Module):
         qkv_states = self.qkv_proj(hidden_states)

         if self.use_hf_rotary:
+            # get qkv
             qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)

             # This updates the query and key states in-place, saving VRAM.
-            query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
+            query, key, value = torch.split(qkv_states, 1, dim=2)
             del qkv_states

-            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            # reshape for hf rotary
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

-            kv_seq_len = key_states.shape[-2]
+            kv_seq_len = key.shape[-2]
             if past_key_value is not None:
                 kv_seq_len += past_key_value[0].shape[-2]

-            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+            cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
+            query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)

         else:
-            query_states, key_states, value_states = qkv_states.chunk(chunks=3, dim=-1)
-            query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+            # get qkv
+            query, key, value = qkv_states.chunk(chunks=3, dim=-1)
+            del qkv_states

-            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            del qkv_states
+            # [num_tokens, num_heads * head_size]
+            query_batch_size, query_len, _ = query.shape
+            query = query.view(query_len*query_batch_size, self.num_heads * self.head_dim)
+
+            # [num_tokens, num_kv_heads * head_size]
+            key_batch_size, key_len, _ = key.shape
+            key = key.view(key_len*key_batch_size, self.num_kv_heads * self.head_dim)
+
+            # [num_tokens]
+            positions = position_ids.view(-1).to(query.device)
+
+            query, key = self.rotary_emb(query, key, positions)
+
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

         is_causal = past_key_value is None
@@ -133,25 +140,25 @@ class QuantLlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to(key_states.device)
+        value = value.to(key.device)

         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)

         if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-            query_states = query_states.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()

-        past_key_value = (key_states, value_states) if use_cache else None
+        past_key_value = (key, value) if use_cache else None

         # with torch.backends.cuda.sdp_kernel(enable_math=False):
-        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
-        del query_states, key_states, value_states
+        attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
+        del query, key, value

         attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
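For context on the tail of forward (only renamed in this commit): the cached key/value are concatenated along the sequence dimension before PyTorch's scaled_dot_product_attention is called, and causal masking is applied only on the prefill pass, when no cache exists yet. A self-contained illustration of that caching pattern with dummy tensors and equal query/KV head counts (shapes only, not the repository's classes):

import torch
import torch.nn.functional as F

bsz, num_heads, head_dim = 1, 8, 64
past_len, q_len = 4, 1

# cached keys/values from earlier steps: [bsz, num_heads, past_len, head_dim]
past_key_value = (
    torch.randn(bsz, num_heads, past_len, head_dim),
    torch.randn(bsz, num_heads, past_len, head_dim),
)

# projections for the current token(s): [bsz, num_heads, q_len, head_dim]
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

# reuse k, v: concatenate along the sequence dimension (dim=2)
key = torch.cat([past_key_value[0], key], dim=2)
value = torch.cat([past_key_value[1], value], dim=2)

# causal masking is only needed when there is no cache yet
is_causal = past_key_value is None

attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
print(attn_output.shape)  # torch.Size([1, 8, 1, 64])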