Commit e71181bd authored by Casper Hansen

Update shapes for cuda kernel

parent 88964968
@@ -38,18 +38,28 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         key: torch.Tensor,
         positions: torch.Tensor,
     ):
+        batch_size, seq_len, _ = query.shape
+        query = query.view(batch_size * seq_len, -1)
+        key = key.view(batch_size * seq_len, -1)
+        positions = positions.view(-1).to(query.device)
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         query = query.contiguous()
         key = key.contiguous()
-        awq_inference_engine.rotary_embedding_neox(
+        awq_inference_engine.rotary_embedding(
             positions,
             query,
             key,
             self.dim,
             self.cos_sin_cache,
+            True # is_neox
         )
+        query = query.view(batch_size, seq_len, -1)
+        key = key.view(batch_size, seq_len, -1)
         return query, key
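
The fused kernel works on flattened [num_tokens, num_heads * head_dim] tensors and updates them in place, which is why forward() collapses the batch and sequence dimensions before the call and restores them afterwards. Below is a minimal pure-PyTorch sketch of that shape contract; rotary_embedding_ref is a hypothetical stand-in for awq_inference_engine.rotary_embedding, assuming a neox-style cache that stores cos in the first half of the last dimension and sin in the second:

import torch

def rotary_embedding_ref(positions, query, key, head_dim, cos_sin_cache):
    # Hypothetical stand-in for the CUDA kernel: neox-style rotary applied
    # in place. query/key: [num_tokens, num_heads * head_dim],
    # positions: [num_tokens], cos_sin_cache: [max_pos, head_dim].
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)   # [num_tokens, head_dim // 2]
    for t in (query, key):
        x = t.view(t.shape[0], -1, head_dim)               # [num_tokens, num_heads, head_dim]
        x1, x2 = x.chunk(2, dim=-1)                        # neox: rotate first/second halves
        r1 = x1 * cos.unsqueeze(1) - x2 * sin.unsqueeze(1)
        r2 = x2 * cos.unsqueeze(1) + x1 * sin.unsqueeze(1)
        t.copy_(torch.cat((r1, r2), dim=-1).reshape(t.shape[0], -1))

# The same shape dance as forward() above.
B, S, H, D = 2, 5, 8, 64
query, key = torch.randn(B, S, H * D), torch.randn(B, S, H * D)
positions = torch.arange(S).repeat(B)                      # flattened to [B * S]
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
freqs = torch.outer(torch.arange(4096).float(), inv_freq)  # [max_pos, D // 2]
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

q_flat, k_flat = query.view(B * S, -1), key.view(B * S, -1)        # flatten for the kernel
rotary_embedding_ref(positions, q_flat, k_flat, D, cos_sin_cache)  # in-place update
query, key = q_flat.view(B, S, -1), k_flat.view(B, S, -1)          # restore [B, S, H * D]

Because q_flat and k_flat are views of query and key, the in-place update is visible through the original tensors without any extra allocation.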
@@ -88,12 +98,13 @@ class QuantLlamaAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()
         qkv_states = self.qkv_proj(hidden_states)
-        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
-        # This updates the query and key states in-place, saving VRAM.
-        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
         if self.use_hf_rotary:
+            qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
+            # This updates the query and key states in-place, saving VRAM.
+            query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
             query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -106,10 +117,13 @@ class QuantLlamaAttention(nn.Module):
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         else:
             query_states, key_states, value_states = qkv_states.chunk(chunks=3, dim=-1)
             query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         del qkv_states
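
Both paths carve query, key, and value out of the packed projection without copying: torch.split and Tensor.chunk return views into qkv_states, so the rotary update writes straight into the projection's storage, which is what the "in-place, saving VRAM" comment refers to. A small sketch of that aliasing, assuming the projection packs the hidden dimension as (3, num_heads, head_dim), as the view in the diff implies:

import torch

bsz, q_len, num_heads, head_dim = 2, 4, 8, 64
qkv_states = torch.randn(bsz, q_len, 3 * num_heads * head_dim)

# Fused-kernel path: chunk along the hidden dim; every chunk is a view.
q, k, v = qkv_states.chunk(chunks=3, dim=-1)
assert q.data_ptr() == qkv_states.data_ptr()  # q starts at the same storage

# HF path: expose the (3, num_heads, head_dim) packing, then split.
qkv5 = qkv_states.view(bsz, q_len, 3, num_heads, head_dim)
q_hf, k_hf, v_hf = torch.split(qkv5, 1, dim=2)
q_hf = q_hf.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

# Writing through a view mutates the packed tensor: no separate q/k buffers.
q.mul_(0)
assert qkv_states[..., :num_heads * head_dim].abs().sum() == 0

Note that transpose(1, 2) only adjusts strides; the saving the comment claims is that no separate query/key copies are materialized for the rotary step itself.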