Commit e71181bd authored by Casper Hansen

Update shapes for cuda kernel

parent 88964968
@@ -38,18 +38,28 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         key: torch.Tensor,
         positions: torch.Tensor,
     ):
+        batch_size, seq_len, _ = query.shape
+        query = query.view(batch_size * seq_len, -1)
+        key = key.view(batch_size * seq_len, -1)
+        positions = positions.view(-1).to(query.device)
+
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         query = query.contiguous()
         key = key.contiguous()
-        awq_inference_engine.rotary_embedding_neox(
+        awq_inference_engine.rotary_embedding(
             positions,
             query,
             key,
             self.dim,
             self.cos_sin_cache,
+            True # is_neox
         )
+
+        query = query.view(batch_size, seq_len, -1)
+        key = key.view(batch_size, seq_len, -1)
+
         return query, key
 
 
 class QuantLlamaAttention(nn.Module):
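Note (not part of the commit): the added lines collapse the batch and sequence dimensions into a single token dimension before the kernel call and restore them afterwards. A minimal standalone sketch of that shape round-trip, with illustrative sizes and the kernel call elided:

    import torch

    # Illustrative sizes only; the real values come from the model config.
    batch_size, seq_len, num_heads, head_dim = 2, 16, 32, 128

    query = torch.randn(batch_size, seq_len, num_heads * head_dim)
    key = torch.randn(batch_size, seq_len, num_heads * head_dim)
    positions = torch.arange(seq_len).repeat(batch_size, 1)

    # Collapse batch and sequence dims to the flat layout passed to the kernel.
    query = query.view(batch_size * seq_len, -1)
    key = key.view(batch_size * seq_len, -1)
    positions = positions.view(-1)

    # awq_inference_engine.rotary_embedding(...) would rotate query/key here.

    # Restore the (batch_size, seq_len, hidden) layout for the attention op.
    query = query.view(batch_size, seq_len, -1)
    key = key.view(batch_size, seq_len, -1)

    assert query.shape == (batch_size, seq_len, num_heads * head_dim)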
@@ -88,12 +98,13 @@ class QuantLlamaAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()
 
         qkv_states = self.qkv_proj(hidden_states)
-        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
-
-        # This updates the query and key states in-place, saving VRAM.
-        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
 
         if self.use_hf_rotary:
+            qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
+
+            # This updates the query and key states in-place, saving VRAM.
+            query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
+
             query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
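Note (not part of the commit): on the HF rotary path, torch.split with size 1 along dim 2 returns views that share storage with qkv_states, which is what the in-place comment refers to; the later .view() drops the singleton dim. A small standalone sketch of those shapes, with illustrative sizes:

    import torch

    bsz, q_len, num_heads, head_dim = 2, 16, 32, 128
    qkv_states = torch.randn(bsz, q_len, 3 * num_heads * head_dim)

    qkv_states = qkv_states.view(bsz, q_len, 3, num_heads, head_dim)
    # Each split is a view into qkv_states (no copy) with a singleton dim 2.
    query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
    assert query_states.shape == (bsz, q_len, 1, num_heads, head_dim)

    # .view() drops the singleton dim, .transpose() moves heads before tokens.
    query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    assert query_states.shape == (bsz, num_heads, q_len, head_dim)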
@@ -106,10 +117,13 @@ class QuantLlamaAttention(nn.Module):
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         else:
+            query_states, key_states, value_states = qkv_states.chunk(chunks=3, dim=-1)
+
             query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
 
             query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
         del qkv_states
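Note (not part of the commit): on the custom kernel path, qkv_states stays in its flat (bsz, q_len, 3 * num_heads * head_dim) layout and is chunked along the last dimension, so query and key reach the QuantLlamaRotaryEmbedding forward shown above as (bsz, q_len, num_heads * head_dim) tensors, matching the flattening in the first hunk. A short sketch with illustrative sizes:

    import torch

    bsz, q_len, num_heads, head_dim = 2, 16, 32, 128
    qkv_states = torch.randn(bsz, q_len, 3 * num_heads * head_dim)

    # Chunking along the hidden dim keeps the (bsz, q_len, hidden) layout that
    # the rotary embedding then flattens to (bsz * q_len, hidden) for the kernel.
    query_states, key_states, value_states = qkv_states.chunk(chunks=3, dim=-1)
    assert query_states.shape == (bsz, q_len, num_heads * head_dim)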