Unverified commit 727172e9 authored by Casper, committed by GitHub

Fused rope theta (#270)

Co-authored-by: Casper Hansen <casperbh96@gmail.com>
parent 5b9f3c47
@@ -118,7 +118,8 @@ class LlamaFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
         self.model.model = LlamaLikeModel(
...
@@ -18,8 +18,7 @@ class MixtralAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: OldMixtralForCausalLM):
         fuser = MixtralFuser(model)
-        # TODO: Fix perplexity on fusing Mixtral
-        #fuser.fuse_transformer()
+        fuser.fuse_transformer()
 
     @staticmethod
     def get_model_layers(model: OldMixtralForCausalLM):
@@ -125,7 +124,8 @@ class MixtralFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
         self.model.model = MixtralModel(
...
@@ -113,7 +113,8 @@ class YiFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
         self.model.model = LlamaLikeModel(
...
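The three fuser changes above (Llama, Mixtral, Yi) all forward rope_theta from the Hugging Face model config instead of letting the fused attention fall back to a hardcoded base of 10000. As a hedged sketch of where that value comes from, assuming a transformers version recent enough that the config defines rope_theta (the get_rope_theta helper below is hypothetical and not part of this commit):

import transformers

def get_rope_theta(model_path: str, default: float = 10000.0) -> float:
    # Read rope_theta from the model config, falling back to the classic base
    # when the config predates the attribute. Mixtral, for example, ships a
    # non-default base of 1e6, which is why hardcoding 10000 was wrong.
    config = transformers.AutoConfig.from_pretrained(model_path)
    return getattr(config, "rope_theta", default)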
@@ -23,11 +23,13 @@ if HF_NEW_CACHE_FORMAT:
 
 class RoPE(nn.Module):
-    def __init__(self, hidden_size, n_heads, max_seq_len, device):
+    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()
         self.freqs_cis = nn.Parameter(
-            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
+            self.precompute_freqs_cis(
+                hidden_size // n_heads, max_seq_len * 2, rope_theta
+            ).to(device),
             requires_grad=False
         )
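The updated call above assumes precompute_freqs_cis now takes the base as a third argument. A minimal Llama-style sketch of such a function (the standard reference formula, not code copied from this repository):

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Per-pair inverse frequencies: theta ** (-2i / dim) for i in [0, dim / 2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    # Rotation angle for every (position, frequency) pair.
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    # Complex exponentials e^{i * angle}, one per position and pair.
    return torch.polar(torch.ones_like(angles), angles)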
@@ -97,7 +99,7 @@ class ALiBi(nn.Module):
 
 class QuantAttentionFused(nn.Module):
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
-                 use_alibi=False, attention_shapes=None):
+                 use_alibi=False, attention_shapes=None, rope_theta=10000):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_heads = n_heads
@@ -111,6 +113,7 @@ class QuantAttentionFused(nn.Module):
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.max_seq_len = max_seq_len
         self.is_hf_transformers = False
+        self.rope_theta = rope_theta
 
         # attention shapes for self attention
         self.attention_shapes = get_attention_shapes(
@@ -127,7 +130,7 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
+            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
             self.rotary_dim = self.head_dim
             self.is_neox = True
@@ -221,7 +224,7 @@ class QuantAttentionFused(nn.Module):
                 alibi_slopes,     # alibi slopes
                 self.start_pos,   # timestep
                 self.rotary_dim,  # rotary embedding dimension
-                10000,            # rotary embedding base
+                self.rope_theta,  # rotary embedding base
                 self.is_neox,     # is neox
             )
             attention_weight = attention_weight.reshape(bsz, 1, -1)
...
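To make the "rotary embedding base" argument concrete: at decode time the kernel rotates each query/key pair by an angle that depends on the timestep and on this base, so passing self.rope_theta keeps the fused path consistent with the table built in RoPE. A rough PyTorch equivalent for the interleaved (non-neox) pairing, illustrative only and not a mirror of the kernel's actual layout handling (the is_neox flag selects between interleaved pairs and split halves):

import torch

def apply_rotary(x: torch.Tensor, pos: int, theta: float) -> torch.Tensor:
    # x: [..., head_dim] with an even head_dim. Pair i is rotated by
    # pos / theta ** (2i / head_dim); theta sets how quickly the rotation
    # frequency decays across the head dimension.
    dim = x.shape[-1]
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = pos * inv_freq
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out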
@@ -5,7 +5,7 @@ from awq.modules.fused.attn import QuantAttentionFused
 class MixtralBlock(nn.Module):
     def __init__(
         self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
-        moe, norm_1, norm_2, dev, max_seq_len
+        moe, norm_1, norm_2, dev, max_seq_len, rope_theta
     ):
         super().__init__()
         self.n_heads = n_heads
@@ -14,7 +14,7 @@ class MixtralBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.moe = moe
@@ -41,7 +41,10 @@ class LlamaLikeBlock(nn.Module):
     LlamaLikeBlock is intended to be reused across blocks that have
     an architecture that closely resembles Llama, e.g. Mistral and Aquila.
     """
-    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp, norm_1, norm_2, dev, max_seq_len):
+    def __init__(
+        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
+        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
@@ -49,7 +52,7 @@ class LlamaLikeBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
...
import transformers
import torch
from lm_eval.base import BaseLM
import fnmatch
import logging


class LMEvalAdaptor(BaseLM):

    def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1):
        super().__init__()

        assert isinstance(batch_size, int)

        self.model_name = model_name
        self.model = model.to(device)
        self.model.eval()

        self.tokenizer = tokenizer

        # assert isinstance(self.tokenizer, (
        #     transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
        #     transformers.T5Tokenizer, transformers.T5TokenizerFast,
        # )), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = batch_size

        self._max_length = max_length

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length != -1:
            return self._max_length
        if hasattr(self.model.config, 'n_ctx'):
            return self.model.config.n_ctx
        elif hasattr(self.model.config, 'max_position_embeddings'):
            return self.model.config.max_position_embeddings
        elif hasattr(self.model.config, 'n_positions'):
            return self.model.config.n_positions
        elif 'bloom' in self.model_name:
            return 2048
        elif 'llama' in self.model_name:
            return 2048  # TODO: did not check this
        elif 'mpt' in self.model_name:
            return 2048
        elif 'falcon' in self.model_name:
            return 2048
        else:
            logging.debug(self.model.config)
            raise NotImplementedError

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return "cuda"

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            if isinstance(self.model, transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
                dec_inps = torch.cat(
                    [
                        torch.tensor(
                            self.model.generation_config.decoder_start_token_id,
                        )
                        .tile(len(inps), 1)
                        .to(inps),
                        inps,
                    ],
                    dim=1,
                )
                kwargs = {"decoder_input_ids": dec_inps}
            else:
                kwargs = {}
            out = self.model(inps, **kwargs)[0]
            if "opt" in self.model_name:  # there are a few extra tokens in opt, which we should omit
                return out[:, :, :50257]
            else:
                return out  # [:, :, :self.tokenizer.vocab_size]

    def _model_generate(self, context, max_length, eos_token_id):
        return self.model.generate(
            context,
            max_length=max_length,
            eos_token_id=eos_token_id,
            do_sample=False
        )
@@ -2,7 +2,6 @@ import argparse
 from lm_eval import evaluator
 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer
-from awq.utils.lm_eval_adaptor import LMEvalAdaptor
 from awq.utils.eval_utils import evaluate_perplexity
 
 def run_eval(
@@ -26,11 +25,9 @@ def run_eval(
         evaluate_perplexity(model.model, tokenizer)
     else:
-        lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
-
         # Evaluate perplexity of quantized model
         results = evaluator.simple_evaluate(
-            model=lm_eval_model,
+            model=model,
             tasks=tasks,
             batch_size=task_batch_size,
             no_cache=True,
...