Unverified commit 727172e9 authored by Casper, committed by GitHub

Fused rope theta (#270)


Co-authored-by: Casper Hansen <casperbh96@gmail.com>
parent 5b9f3c47
@@ -118,7 +118,8 @@ class LlamaFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
         self.model.model = LlamaLikeModel(
......
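Note (not part of the diff): the fuser changes in this commit forward rope_theta straight from the Hugging Face config. A minimal sketch of reading that value, where the model id is a placeholder and the getattr fallback is an assumption for configs that predate the rope_theta field:

from transformers import AutoConfig

# Illustrative only; not repo code.
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
rope_theta = getattr(config, "rope_theta", 10000.0)
print(rope_theta)  # newer Mistral-style configs report 1000000.0; Llama 2 reports 10000.0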
@@ -18,8 +18,7 @@ class MixtralAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: OldMixtralForCausalLM):
         fuser = MixtralFuser(model)
-        # TODO: Fix perplexity on fusing Mixtral
-        #fuser.fuse_transformer()
+        fuser.fuse_transformer()
 
     @staticmethod
     def get_model_layers(model: OldMixtralForCausalLM):
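With fusing re-enabled for Mixtral, the fused path is exercised whenever a quantized checkpoint is loaded with layer fusion turned on. A hedged usage sketch (the checkpoint path is a placeholder):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "path/to/mixtral-awq"  # hypothetical local quantized checkpoint
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)  # triggers MixtralAWQForCausalLM.fuse_layers
tokenizer = AutoTokenizer.from_pretrained(quant_path)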
@@ -125,7 +124,8 @@ class MixtralFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
         self.model.model = MixtralModel(
......
@@ -113,7 +113,8 @@ class YiFuser:
                 norm_1=norm_1,
                 norm_2=norm_2,
                 dev=device,
-                max_seq_len=self.model.config.max_new_tokens
+                max_seq_len=self.model.config.max_new_tokens,
+                rope_theta=self.model.config.rope_theta
             ))
         self.model.model = LlamaLikeModel(
......
@@ -23,11 +23,13 @@ if HF_NEW_CACHE_FORMAT:
 class RoPE(nn.Module):
-    def __init__(self, hidden_size, n_heads, max_seq_len, device):
+    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()
         self.freqs_cis = nn.Parameter(
-            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
+            self.precompute_freqs_cis(
+                hidden_size // n_heads, max_seq_len * 2, rope_theta
+            ).to(device),
             requires_grad=False
         )
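For context, rope_theta is the base of the rotary inverse-frequency schedule that precompute_freqs_cis builds. A Llama-style sketch of that helper (illustrative; the in-repo implementation may differ in details):

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # One complex rotation per (position, frequency pair); a larger theta means
    # slower rotation in the low-frequency pairs and longer usable context.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, device=freqs.device).float()
    freqs = torch.outer(t, freqs)                      # (end, dim // 2)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64 rotary table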
@@ -97,7 +99,7 @@ class ALiBi(nn.Module):
 class QuantAttentionFused(nn.Module):
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
-                 use_alibi=False, attention_shapes=None):
+                 use_alibi=False, attention_shapes=None, rope_theta=10000):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_heads = n_heads
@@ -111,6 +113,7 @@ class QuantAttentionFused(nn.Module):
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.max_seq_len = max_seq_len
         self.is_hf_transformers = False
+        self.rope_theta = rope_theta
         # attention shapes for self attention
         self.attention_shapes = get_attention_shapes(
@@ -127,7 +130,7 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
+            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
             self.rotary_dim = self.head_dim
             self.is_neox = True
@@ -221,7 +224,7 @@ class QuantAttentionFused(nn.Module):
                 alibi_slopes, # alibi slopes
                 self.start_pos, # timestep
                 self.rotary_dim, # rotary embedding dimension
-                10000, # rotary embedding base
+                self.rope_theta, # rotary embedding base
                 self.is_neox, # is neox
             )
             attention_weight = attention_weight.reshape(bsz, 1, -1)
......
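Passing self.rope_theta into the fused kernel matters because the base sets the wavelength of every rotary frequency pair. A small standalone illustration (not repo code):

import math
import torch

def rope_wavelengths(dim: int, theta: float) -> torch.Tensor:
    # Wavelength, in tokens, of each rotary pair: 2 * pi / frequency.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    return 2 * math.pi / inv_freq

# For head_dim=128 the slowest pair spans tens of thousands of positions at
# theta=10000 but millions at theta=1e6, so hard-coding 10000 distorts positions
# for models trained with a larger base.
print(rope_wavelengths(128, 10_000).max(), rope_wavelengths(128, 1_000_000).max())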
@@ -5,7 +5,7 @@ from awq.modules.fused.attn import QuantAttentionFused
 class MixtralBlock(nn.Module):
     def __init__(
         self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
-        moe, norm_1, norm_2, dev, max_seq_len
+        moe, norm_1, norm_2, dev, max_seq_len, rope_theta
     ):
         super().__init__()
         self.n_heads = n_heads
@@ -14,7 +14,7 @@ class MixtralBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.moe = moe
@@ -41,7 +41,10 @@ class LlamaLikeBlock(nn.Module):
     LlamaLikeBlock is intended to be reused across blocks that have
     an architecture that closely resembles Llama, e.g. Mistral and Aquila.
     """
-    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp, norm_1, norm_2, dev, max_seq_len):
+    def __init__(
+        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
+        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
@@ -49,7 +52,7 @@ class LlamaLikeBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
......
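The blocks above simply thread rope_theta through to QuantAttentionFused. A hypothetical helper showing the same wiring a fuser performs; the getattr fallback is my assumption rather than part of this patch:

from awq.modules.fused.attn import QuantAttentionFused

def build_fused_attention(cfg, qkv_layer, o_proj, device, max_seq_len):
    # cfg is a Hugging Face config; qkv_layer / o_proj are the already-quantized
    # projections produced earlier in the fuser.
    return QuantAttentionFused(
        cfg.hidden_size,
        cfg.num_attention_heads,
        cfg.num_key_value_heads,
        qkv_layer,
        o_proj,
        dev=device,
        max_seq_len=max_seq_len,
        use_alibi=False,
        rope_theta=getattr(cfg, "rope_theta", 10000.0),
    )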
import transformers
import torch
from lm_eval.base import BaseLM
import fnmatch
import logging
class LMEvalAdaptor(BaseLM):
def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1):
super().__init__()
assert isinstance(batch_size, int)
self.model_name = model_name
self.model = model.to(device)
self.model.eval()
self.tokenizer = tokenizer
# assert isinstance(self.tokenizer, (
# transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
# transformers.T5Tokenizer, transformers.T5TokenizerFast,
# )), "this tokenizer has not been checked for compatibility yet!"
self.vocab_size = self.tokenizer.vocab_size
self._batch_size = batch_size
self._max_length = max_length
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
if self._max_length != -1:
return self._max_length
if hasattr(self.model.config, 'n_ctx'):
return self.model.config.n_ctx
elif hasattr(self.model.config, 'max_position_embeddings'):
return self.model.config.max_position_embeddings
elif hasattr(self.model.config, 'n_positions'):
return self.model.config.n_positions
elif 'bloom' in self.model_name:
return 2048
elif 'llama' in self.model_name:
return 2048 # TODO: did not check this
elif 'mpt' in self.model_name:
return 2048
elif 'falcon' in self.model_name:
return 2048
else:
logging.debug(self.model.config)
raise NotImplementedError
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
return self._batch_size
@property
def device(self):
return "cuda"
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
if isinstance(self.model, transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
dec_inps = torch.cat(
[
torch.tensor(
self.model.generation_config.decoder_start_token_id,
)
.tile(len(inps), 1)
.to(inps),
inps,
],
dim=1,
)
kwargs = {"decoder_input_ids": dec_inps,}
else:
kwargs = {}
out = self.model(inps, **kwargs)[0]
if "opt" in self.model_name: # there are a few extra tokens in opt, which we should omit
return out[:, :, :50257]
else:
return out # [:, :, :self.tokenizer.vocab_size]
def _model_generate(self, context, max_length, eos_token_id):
return self.model.generate(
context,
max_length=max_length,
eos_token_id=eos_token_id,
do_sample=False
)
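For reference, the adaptor above plugs into the old-style lm_eval harness roughly as in the usage removed from the eval script below. A sketch, with task names illustrative:

from lm_eval import evaluator

def run_harness(model, tokenizer, tasks=("wikitext",), batch_size=1):
    # Wraps an already-loaded model/tokenizer for lm_eval 0.3-style evaluation.
    lm = LMEvalAdaptor("model", model, tokenizer, device="cuda", batch_size=batch_size)
    return evaluator.simple_evaluate(
        model=lm,
        tasks=list(tasks),
        batch_size=batch_size,
        no_cache=True,
    )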
@@ -2,7 +2,6 @@ import argparse
 from lm_eval import evaluator
 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer
-from awq.utils.lm_eval_adaptor import LMEvalAdaptor
 from awq.utils.eval_utils import evaluate_perplexity
 
 def run_eval(
@@ -26,11 +25,9 @@ def run_eval(
         evaluate_perplexity(model.model, tokenizer)
     else:
-        lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
-        # Evaluate perplexity of quantized model
         results = evaluator.simple_evaluate(
-            model=lm_eval_model,
+            model=model,
             tasks=tasks,
             batch_size=task_batch_size,
             no_cache=True,
......