Commit ab7d68e7 authored by Casper Hansen

Correct input sizes

parent a4626828
@@ -99,6 +99,7 @@ class LlamaFuser:
         attn = QuantLlamaAttention(
             module.hidden_size,
             module.num_heads,
+            module.num_key_value_heads,
             qkv_layer,
             module.o_proj,
             qkv_layer.qweight.device,
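This hunk forwards the model's key/value head count into the fused attention module, which is what grouped-query models need: their key and value projections use fewer heads than the query projection. A rough shape sketch with illustrative numbers (not values read from this repository or its fusion code):

```python
# Illustrative GQA width arithmetic; all concrete numbers here are assumptions.
hidden_size = 8192
num_heads = 64
num_key_value_heads = 8
head_dim = hidden_size // num_heads            # 128

q_width = num_heads * head_dim                 # 8192 output features for Q
kv_width = num_key_value_heads * head_dim      # 1024 output features each for K and V
fused_width = q_width + 2 * kv_width           # 10240 total for a fused QKV projection
print(q_width, kv_width, fused_width)
```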
@@ -30,6 +30,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
+        # [max_position, rot_dim]
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

     def forward(
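The added comment documents the cache layout. For reference, a hedged sketch of how a [max_position, rot_dim] cos/sin cache of this form is typically built; `base`, `rot_dim`, and `max_position` are placeholders, and only the cos/sin/cat lines correspond to the hunk above:

```python
import torch

rot_dim, max_position, base = 128, 2048, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2).float() / rot_dim))  # [rot_dim // 2]
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)                                # [max_position, rot_dim // 2]

cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)                                       # [max_position, rot_dim]
print(cache.shape)  # torch.Size([2048, 128])
```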
@@ -38,11 +39,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         key: torch.Tensor,
         positions: torch.Tensor,
     ):
-        batch_size, seq_len, _ = query.shape
-        query = query.view(batch_size * seq_len, -1)
-        key = key.view(batch_size * seq_len, -1)
-        positions = positions.view(-1).to(query.device)
-
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         query = query.contiguous()
@@ -57,9 +53,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
             True # is_neox
         )
-
-        query = query.view(batch_size, seq_len, -1)
-        key = key.view(batch_size, seq_len, -1)
         return query, key


 class QuantLlamaAttention(nn.Module):
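With the internal flatten/unflatten logic removed, the rotary forward now relies on its caller to hand it already-flattened tensors. A minimal sketch of that implied contract, with placeholder sizes:

```python
import torch

bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 5, 32, 8, 128

query = torch.randn(bsz * seq_len, num_heads * head_dim)     # [num_tokens, num_heads * head_dim]
key = torch.randn(bsz * seq_len, num_kv_heads * head_dim)    # [num_tokens, num_kv_heads * head_dim]
positions = torch.arange(seq_len).repeat(bsz)                # [num_tokens], on the same device as query

# rotary_emb(query, key, positions) is now called with these flat shapes, and the
# caller is responsible for reshaping the outputs back to [bsz, seq_len, ...].
```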
@@ -69,15 +62,17 @@ class QuantLlamaAttention(nn.Module):
         self,
         hidden_size,
         num_heads,
+        num_kv_heads,
         qkv_proj,
         o_proj,
         dev,
         max_new_tokens,
-        use_hf_rotary=True
+        use_hf_rotary=False
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
         self.head_dim = hidden_size // num_heads
         self.use_hf_rotary = use_hf_rotary
@@ -100,32 +95,44 @@ class QuantLlamaAttention(nn.Module):
         qkv_states = self.qkv_proj(hidden_states)

         if self.use_hf_rotary:
+            # get qkv
             qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
+            query, key, value = torch.split(qkv_states, 1, dim=2)
+            del qkv_states

-            # This updates the query and key states in-place, saving VRAM.
-            query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
-
-            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            # reshape for hf rotary
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

-            kv_seq_len = key_states.shape[-2]
+            kv_seq_len = key.shape[-2]

             if past_key_value is not None:
                 kv_seq_len += past_key_value[0].shape[-2]

-            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+            cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
+            query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
         else:
-            query_states, key_states, value_states = qkv_states.chunk(chunks=3, dim=-1)
-            query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
-
-            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-            del qkv_states
+            # get qkv
+            query, key, value = qkv_states.chunk(chunks=3, dim=-1)
+            del qkv_states

+            # [num_tokens, num_heads * head_size]
+            query_batch_size, query_len, _ = query.shape
+            query = query.view(query_len*query_batch_size, self.num_heads * self.head_dim)

+            # [num_tokens, num_kv_heads * head_size]
+            key_batch_size, key_len, _ = key.shape
+            key = key.view(key_len*key_batch_size, self.num_kv_heads * self.head_dim)

+            # [num_tokens]
+            positions = position_ids.view(-1).to(query.device)

+            query, key = self.rotary_emb(query, key, positions)

+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

         is_causal = past_key_value is None
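The else branch above is where the commit's "correct input sizes" land: query and key are flattened to a token dimension before the fused rotary kernel and reshaped afterwards for attention. Below is a self-contained walk-through of the same reshape sequence with toy dimensions; the lambda is a stand-in for the real rotary kernel, and `reshape` is used where flattening a chunked view with `view` would not be valid:

```python
import torch

bsz, q_len, num_heads, head_dim = 2, 4, 8, 16
num_kv_heads = num_heads   # chunk(chunks=3) produces equal widths, i.e. the MHA case

qkv = torch.randn(bsz, q_len, 3 * num_heads * head_dim)
query, key, value = qkv.chunk(chunks=3, dim=-1)

# flatten batch and sequence into one token dimension for the rotary kernel
query = query.reshape(bsz * q_len, num_heads * head_dim)       # [num_tokens, num_heads * head_dim]
key = key.reshape(bsz * q_len, num_kv_heads * head_dim)        # [num_tokens, num_kv_heads * head_dim]
positions = torch.arange(q_len).repeat(bsz)                    # [num_tokens]

rotary_emb = lambda q, k, pos: (q, k)                          # placeholder for the fused kernel
query, key = rotary_emb(query, key, positions)

# back to [bsz, heads, q_len, head_dim] for scaled_dot_product_attention
query = query.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key = key.view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
value = value.reshape(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(query.shape, key.shape, value.shape)                     # each torch.Size([2, 8, 4, 16])
```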
@@ -133,25 +140,25 @@ class QuantLlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to(key_states.device)
+        value = value.to(key.device)

         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)

         if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-            query_states = query_states.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()

-        past_key_value = (key_states, value_states) if use_cache else None
+        past_key_value = (key, value) if use_cache else None

         # with torch.backends.cuda.sdp_kernel(enable_math=False):
-        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
-        del query_states, key_states, value_states
+        attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
+        del query, key, value

         attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
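The contiguous() calls before caching exist because query, key, and value are views into the single fused qkv_proj output, so caching the views would keep the whole fused buffer alive. A small, repo-independent illustration of that storage sharing (Tensor.untyped_storage() needs a recent PyTorch):

```python
import torch

qkv = torch.randn(1, 4, 3 * 64)            # stand-in for a fused QKV projection output
_, key, _ = qkv.chunk(chunks=3, dim=-1)    # key is a view into qkv's storage

print(key.untyped_storage().data_ptr() == qkv.untyped_storage().data_ptr())  # True: shares the big buffer
key = key.contiguous()                     # copies just the K slice
print(key.untyped_storage().data_ptr() == qkv.untyped_storage().data_ptr())  # False: the copy owns its own storage
```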