Commit b7290868 authored by Casper Hansen's avatar Casper Hansen
Browse files

Use hf_rotary per default

parent 85430ddc
......@@ -122,9 +122,9 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
# Prints
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
context_tokens_per_second = n_context / context_time * batch_size
context_ms_per_token = (context_time*1000) / n_context * batch_size
context_ms_per_token = (context_time*1000) / n_context / batch_size
inference_tokens_per_second = n_generate / generation_time * batch_size
inference_ms_per_token = (generation_time*1000) / n_generate * batch_size
inference_ms_per_token = (generation_time*1000) / n_generate / batch_size
print(f"[======] Model summary: {model_path} [======]")
print(f"[*] Load time: {load_time:.2f} seconds")
......@@ -185,9 +185,6 @@ if __name__ == '__main__':
run_eval(args.model_path, args.quant_file, args.device,
args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
elif args.entry_type == 'speed':
if args.batch_size > 1 and not args.disable_fused_layers:
raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
else:
raise Exception('--entry_type must be one of (search|quant|eval|speed)')
......@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding
class QuantLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
......@@ -41,6 +42,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
# to the attention op.
query = query.contiguous()
key = key.contiguous()
awq_inference_engine.rotary_embedding_neox(
positions,
query,
......@@ -60,19 +62,25 @@ class QuantLlamaAttention(nn.Module):
qkv_proj,
o_proj,
dev,
max_new_tokens
max_new_tokens,
use_hf_rotary=True
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.use_hf_rotary = use_hf_rotary
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads}).")
self.qkv_proj = qkv_proj
self.o_proj = o_proj
self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
if use_hf_rotary:
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_new_tokens, device=dev)
else:
self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
"""Input shape: Batch x Time x Channel"""
......@@ -84,13 +92,27 @@ class QuantLlamaAttention(nn.Module):
# This updates the query and key states in-place, saving VRAM.
query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
if self.use_hf_rotary:
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
else:
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
del qkv_states
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
is_causal = past_key_value is None
kv_seq_len = q_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment