Commit b7290868 authored by Casper Hansen

Use hf_rotary per default

parent 85430ddc
@@ -122,9 +122,9 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
     context_tokens_per_second = n_context / context_time * batch_size
-    context_ms_per_token = (context_time*1000) / n_context * batch_size
+    context_ms_per_token = (context_time*1000) / n_context / batch_size
     inference_tokens_per_second = n_generate / generation_time * batch_size
-    inference_ms_per_token = (generation_time*1000) / n_generate * batch_size
+    inference_ms_per_token = (generation_time*1000) / n_generate / batch_size

     print(f"[======] Model summary: {model_path} [======]")
     print(f"[*] Load time: {load_time:.2f} seconds")
@@ -185,9 +185,6 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        if args.batch_size > 1 and not args.disable_fused_layers:
-            raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
-
         run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
     else:
         raise Exception('--entry_type must be one of (search|quant|eval|speed)')
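With the batch-size guard removed, the speed entry point can run with batch_size > 1 while keeping the fused layers enabled. A hedged usage sketch calling run_speed directly; every value below is a placeholder, not a repository default, and the last positional argument mirrors args.disable_fused_layers from the call above:

    run_speed(
        "path/to/model",       # model_path (placeholder)
        "quantized_model.pt",  # quant_file (placeholder)
        "cuda:0",              # device
        128,                   # n_generate
        256,                   # n_context
        8,                     # batch_size > 1 now runs with fused layers
        False,                 # keep fused layers enabled
    )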
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding


 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -41,6 +42,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         # to the attention op.
         query = query.contiguous()
         key = key.contiguous()
+
         awq_inference_engine.rotary_embedding_neox(
             positions,
             query,
@@ -60,18 +62,24 @@ class QuantLlamaAttention(nn.Module):
         qkv_proj,
         o_proj,
         dev,
-        max_new_tokens
+        max_new_tokens,
+        use_hf_rotary=True
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.head_dim = hidden_size // num_heads
+        self.use_hf_rotary = use_hf_rotary

         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                              f" and `num_heads`: {num_heads}).")
         self.qkv_proj = qkv_proj
         self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
+
+        if use_hf_rotary:
+            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_new_tokens, device=dev)
+        else:
+            self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)

     def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
@@ -84,13 +92,27 @@ class QuantLlamaAttention(nn.Module):
         # This updates the query and key states in-place, saving VRAM.
         query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
-        del qkv_states
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if self.use_hf_rotary:
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+            kv_seq_len = key_states.shape[-2]
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        else:
+            query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        del qkv_states

         is_causal = past_key_value is None
         kv_seq_len = q_len
...
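The substantive change is in QuantLlamaAttention: with the new use_hf_rotary=True default, the module builds a transformers LlamaRotaryEmbedding and rotates queries and keys with apply_rotary_pos_emb on (bsz, num_heads, seq_len, head_dim) tensors, while the fused awq_inference_engine.rotary_embedding_neox path remains available behind use_hf_rotary=False. A self-contained sketch of the now-default path, assuming the transformers API of this commit's era, where LlamaRotaryEmbedding.forward(x, seq_len=...) returns (cos, sin) and apply_rotary_pos_emb takes (q, k, cos, sin, position_ids); later transformers releases changed these signatures:

    import torch
    from transformers.models.llama.modeling_llama import (
        LlamaRotaryEmbedding,
        apply_rotary_pos_emb,
    )

    # Toy shapes for illustration, not model defaults.
    bsz, num_heads, q_len, head_dim = 1, 32, 8, 128
    query = torch.randn(bsz, num_heads, q_len, head_dim)
    key = torch.randn(bsz, num_heads, q_len, head_dim)
    value = torch.randn(bsz, num_heads, q_len, head_dim)
    position_ids = torch.arange(q_len).unsqueeze(0)

    # Without a KV cache, kv_seq_len is just q_len; the diff adds past lengths when present.
    rotary_emb = LlamaRotaryEmbedding(head_dim, max_position_embeddings=2048)
    cos, sin = rotary_emb(value, seq_len=q_len)
    query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)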