Unverified Commit bcaa8a36 authored by Casper, committed by GitHub

v0.2.0 (#330)


Co-authored-by: jinz2014 <7799920+jinz2014@users.noreply.github.com>
Co-authored-by: Jin Z <5zj@cousteau.ftpn.ornl.gov>
parent c69d3b65
...@@ -91,7 +91,7 @@ jobs: ...@@ -91,7 +91,7 @@ jobs:
# Install torch # Install torch
$cudaVersion = $env:CUDA_VERSION.Replace('.', '') $cudaVersion = $env:CUDA_VERSION.Replace('.', '')
$cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1) $cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1)
if ([int]$cudaVersionPytorch -gt 118) { $pytorchVersion = "torch==2.1.0" } else {$pytorchVersion = "torch==2.0.1"} if ([int]$cudaVersionPytorch -gt 118) { $pytorchVersion = "torch==2.2.0" } else {$pytorchVersion = "torch==2.0.1"}
python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch
python -m pip install build setuptools wheel ninja python -m pip install build setuptools wheel ninja
......
name: Documentation
on:
push:
branches:
- main
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v4
with:
python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v3
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-docs
- run: pip install mkdocstrings-python mkdocs-material griffe-typingdoc
- run: mkdocs gh-deploy --force
\ No newline at end of file
...@@ -70,33 +70,6 @@ All three methods will install the latest and correct kernels for your system fr ...@@ -70,33 +70,6 @@ All three methods will install the latest and correct kernels for your system fr
If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source. If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source.
## Supported models
The detailed support list:
| Models | Sizes |
| -------- | --------------------------- |
| LLaMA-2 | 7B/13B/70B |
| LLaMA | 7B/13B/30B/65B |
| Mistral | 7B |
| Vicuna | 7B/13B |
| MPT | 7B/30B |
| Falcon | 7B/40B |
| OPT | 125m/1.3B/2.7B/6.7B/13B/30B |
| Bloom | 560m/3B/7B/ |
| GPTJ | 6.7B |
| Aquila | 7B |
| Aquila2 | 7B/34B |
| Yi | 6B/34B |
| Qwen | 1.8B/7B/14B/72B |
| BigCode | 1B/7B/15B |
| GPT NeoX | 20B |
| GPT-J | 6B |
| LLaVa | 7B/13B |
| Mixtral | 8x7B |
| Baichuan | 7B/13B |
| QWen | 1.8B/7B/14B/72B |
## Usage ## Usage
Under examples, you can find examples of how to quantize, run inference, and benchmark AutoAWQ models. Under examples, you can find examples of how to quantize, run inference, and benchmark AutoAWQ models.
...@@ -122,7 +95,7 @@ Fused modules are a large part of the speedup you get from AutoAWQ. The idea is ...@@ -122,7 +95,7 @@ Fused modules are a large part of the speedup you get from AutoAWQ. The idea is
- Fused modules are activated when you use `fuse_layers=True`. - Fused modules are activated when you use `fuse_layers=True`.
- A custom cache is implemented. It preallocates based on batch size and sequence length. - A custom cache is implemented. It preallocates based on batch size and sequence length.
- You cannot change the sequence length after you have created your model. - You cannot change the sequence length after you have created your model.
- Reference: `AutoAWQForCausalLM.from_quantized(max_new_tokens=seq_len, batch_size=batch_size)` - Reference: `AutoAWQForCausalLM.from_quantized(max_seq_len=seq_len, batch_size=batch_size)`
- The main accelerator in the fused modules comes from FasterTransformer, which is only compatible with Linux. - The main accelerator in the fused modules comes from FasterTransformer, which is only compatible with Linux.
- The `past_key_values` from `model.generate()` are only dummy values, so they cannot be used after generation. - The `past_key_values` from `model.generate()` are only dummy values, so they cannot be used after generation.
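Putting the notes above together, a minimal sketch of loading a quantized checkpoint with fused modules under the renamed `max_seq_len` argument could look like the following. The model path is the placeholder checkpoint used in the HumanEval script later in this diff, and the concrete lengths are illustrative values, not defaults prescribed by the library:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Placeholder checkpoint; any AWQ-quantized model should behave similarly.
quant_path = "TheBloke/zephyr-7B-beta-AWQ"

# fuse_layers=True enables the fused modules. The custom cache is preallocated
# from max_seq_len and batch_size, and the sequence length cannot be changed
# after the model has been created.
model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    fuse_layers=True,
    max_seq_len=2048,
    batch_size=1,
)
tokenizer = AutoTokenizer.from_pretrained(quant_path)
```

Because the cache is sized up front, choose `max_seq_len` for the longest prompt plus generation you expect rather than trying to resize it later.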
...@@ -194,7 +167,7 @@ tokens = tokenizer( ...@@ -194,7 +167,7 @@ tokens = tokenizer(
generation_output = model.generate( generation_output = model.generate(
tokens, tokens,
streamer=streamer, streamer=streamer,
max_new_tokens=512 max_seq_len=512
) )
``` ```
......
__version__ = "0.1.8" __version__ = "0.2.0"
from awq.models.auto import AutoAWQForCausalLM from awq.models.auto import AutoAWQForCausalLM
...@@ -9,21 +9,23 @@ from lm_eval.tasks import initialize_tasks ...@@ -9,21 +9,23 @@ from lm_eval.tasks import initialize_tasks
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.whisper.english_normalizer import BasicTextNormalizer from transformers.models.whisper.english_normalizer import BasicTextNormalizer
def get_device(): def get_device():
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
return 'mps' return "mps"
elif torch.cuda.is_available(): elif torch.cuda.is_available():
return 'cuda:0' return "cuda:0"
else: else:
return 'cpu' return "cpu"
def evaluate_perplexity(model, tokenizer): def evaluate_perplexity(model, tokenizer):
def _perplexity(nlls, n_samples, seqlen): def _perplexity(nlls, n_samples, seqlen):
return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen)) return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))
# load and prepare dataset # load and prepare dataset
data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
data = tokenizer("\n\n".join(data['text']), return_tensors='pt') data = tokenizer("\n\n".join(data["text"]), return_tensors="pt")
data = data.input_ids.to(model.device) data = data.input_ids.to(model.device)
seqlen = 2048 seqlen = 2048
...@@ -34,25 +36,28 @@ def evaluate_perplexity(model, tokenizer): ...@@ -34,25 +36,28 @@ def evaluate_perplexity(model, tokenizer):
with tqdm(range(n_samples), desc="Perplexity -") as progress_bar: with tqdm(range(n_samples), desc="Perplexity -") as progress_bar:
for i in progress_bar: for i in progress_bar:
start_index = (i * seqlen) start_index = i * seqlen
end_index = ((i + 1) * seqlen) end_index = (i + 1) * seqlen
batch = data[:, start_index:end_index].to(model.device) batch = data[:, start_index:end_index].to(model.device)
with torch.no_grad(): with torch.no_grad():
logits = model(batch).logits logits = model(batch).logits
shift_logits = logits[:, :-1, :].contiguous().float() shift_logits = logits[:, :-1, :].contiguous().float()
shift_labels = data[:, start_index:end_index][:, 1:] shift_labels = data[:, start_index:end_index][:, 1:]
loss_fct = nn.CrossEntropyLoss() loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
neg_log_likelihood = loss.float() * seqlen neg_log_likelihood = loss.float() * seqlen
nlls.append(neg_log_likelihood) nlls.append(neg_log_likelihood)
curr_ppl = _perplexity(nlls, i+1, seqlen) curr_ppl = _perplexity(nlls, i + 1, seqlen)
progress_bar.set_description(f"Perplexity {curr_ppl:.3f}") progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")
ppl = _perplexity(nlls, n_samples, seqlen) ppl = _perplexity(nlls, n_samples, seqlen)
return ppl.item() return ppl.item()
def eval_librispeech(model_id, num_samples=100, batch_size=4): def eval_librispeech(model_id, num_samples=100, batch_size=4):
try: try:
import jiwer, librosa, soundfile import jiwer, librosa, soundfile
...@@ -79,7 +84,8 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4): ...@@ -79,7 +84,8 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4):
texts = [] texts = []
audio = [] audio = []
for i, data in tqdm(enumerate(dataset), total=num_samples, desc="Loading dataset"): for i, data in tqdm(enumerate(dataset), total=num_samples, desc="Loading dataset"):
if len(audio) == num_samples: break if len(audio) == num_samples:
break
audio.append(data["audio"]) audio.append(data["audio"])
texts.append(data["text"]) texts.append(data["text"])
...@@ -88,8 +94,8 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4): ...@@ -88,8 +94,8 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4):
with tqdm(range(0, num_samples, batch_size), desc="Word Error Rate: -") as pbar: with tqdm(range(0, num_samples, batch_size), desc="Word Error Rate: -") as pbar:
for i in pbar: for i in pbar:
batch_audio = audio[i:i+batch_size] batch_audio = audio[i : i + batch_size]
batch_texts = texts[i:i+batch_size] batch_texts = texts[i : i + batch_size]
# inference # inference
results = pipe(batch_audio, batch_size=len(batch_audio)) results = pipe(batch_audio, batch_size=len(batch_audio))
...@@ -102,12 +108,22 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4): ...@@ -102,12 +108,22 @@ def eval_librispeech(model_id, num_samples=100, batch_size=4):
references.extend(normalized_texts) references.extend(normalized_texts)
# word error rate computation # word error rate computation
wer = wer_metric.compute(predictions=predictions, references=references) * 100 wer = (
wer_metric.compute(predictions=predictions, references=references) * 100
)
pbar.set_description(f"Word Error Rate: {wer:.3f}%") pbar.set_description(f"Word Error Rate: {wer:.3f}%")
def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", task_use_pretrained=False):
def eval_mmlu(
model_path="gpt2",
num_fewshot=1,
batch_size=1,
device="cuda:0",
task_use_pretrained=False,
):
try: try:
import vllm import vllm
VLLM_INSTALLED = True VLLM_INSTALLED = True
except ImportError: except ImportError:
VLLM_INSTALLED = False VLLM_INSTALLED = False
...@@ -133,12 +149,12 @@ def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", t ...@@ -133,12 +149,12 @@ def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", t
dtype="float16", dtype="float16",
trust_remote_code=True, trust_remote_code=True,
) )
model_args = ",".join([f"{k}={v}" for k,v in model_args.items()]) model_args = ",".join([f"{k}={v}" for k, v in model_args.items()])
results = evaluator.simple_evaluate( results = evaluator.simple_evaluate(
model=model, model=model,
model_args=model_args, model_args=model_args,
tasks=['mmlu'], tasks=["mmlu"],
num_fewshot=num_fewshot, num_fewshot=num_fewshot,
batch_size=batch_size, batch_size=batch_size,
device=device, device=device,
...@@ -147,7 +163,8 @@ def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", t ...@@ -147,7 +163,8 @@ def eval_mmlu(model_path="gpt2", num_fewshot=1, batch_size=1, device="cuda:0", t
print(evaluator.make_table(results)) print(evaluator.make_table(results))
if __name__ == '__main__':
if __name__ == "__main__":
### PERPLEXITY ### PERPLEXITY
# model_path = 'mistralai/Mistral-7B-Instruct-v0.1' # model_path = 'mistralai/Mistral-7B-Instruct-v0.1'
# model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") # model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
......
...@@ -30,25 +30,27 @@ from transformers import ( ...@@ -30,25 +30,27 @@ from transformers import (
PreTrainedTokenizer, PreTrainedTokenizer,
) )
def eval_humaneval( def eval_humaneval(
model: PreTrainedModel, model: PreTrainedModel,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
out_path: str = "humaneval_out.jsonl", out_path: str = "humaneval_out.jsonl",
format_tabs: bool = True, format_tabs: bool = True,
): ):
problems = {example["task_id"]: example for example in load_dataset("openai_humaneval")["test"]} problems = {
example["task_id"]: example
for example in load_dataset("openai_humaneval")["test"]
}
samples = [] samples = []
for i, (task_id, task) in tqdm(enumerate(problems.items()), total=len(problems) ): for i, (task_id, task) in tqdm(enumerate(problems.items()), total=len(problems)):
if format_tabs: if format_tabs:
prompt = task["prompt"].replace(" ", "\t") prompt = task["prompt"].replace(" ", "\t")
else: else:
prompt = task["prompt"] prompt = task["prompt"]
batch_completions = generate_batch_completion( batch_completions = generate_batch_completion(model, tokenizer, prompt, 1)
model, tokenizer, prompt, 1
)
for sample in batch_completions: for sample in batch_completions:
result = dict( result = dict(
...@@ -58,9 +60,9 @@ def eval_humaneval( ...@@ -58,9 +60,9 @@ def eval_humaneval(
samples += [result] samples += [result]
with open(out_path, 'wb') as fp: with open(out_path, "wb") as fp:
for x in samples: for x in samples:
fp.write((json.dumps(x) + "\n").encode('utf-8')) fp.write((json.dumps(x) + "\n").encode("utf-8"))
results = evaluate_functional_correctness( results = evaluate_functional_correctness(
sample_file=out_path, sample_file=out_path,
...@@ -71,6 +73,7 @@ def eval_humaneval( ...@@ -71,6 +73,7 @@ def eval_humaneval(
print(results) print(results)
@torch.inference_mode() @torch.inference_mode()
def generate_batch_completion( def generate_batch_completion(
model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size
...@@ -82,7 +85,7 @@ def generate_batch_completion( ...@@ -82,7 +85,7 @@ def generate_batch_completion(
generated_ids = model.generate( generated_ids = model.generate(
**inputs, **inputs,
use_cache=True, use_cache=True,
max_new_tokens=512, max_seq_len=512,
temperature=0.2, temperature=0.2,
top_p=0.95, top_p=0.95,
do_sample=True, do_sample=True,
...@@ -106,8 +109,9 @@ def generate_batch_completion( ...@@ -106,8 +109,9 @@ def generate_batch_completion(
return [filter_code(fix_indents(completion)) for completion in batch_completions] return [filter_code(fix_indents(completion)) for completion in batch_completions]
def check_correctness(problem: Dict, completion: str, timeout: float, def check_correctness(
completion_id: Optional[int] = None) -> Dict: problem: Dict, completion: str, timeout: float, completion_id: Optional[int] = None
) -> Dict:
""" """
Evaluates the functional correctness of a completion by running the test Evaluates the functional correctness of a completion by running the test
suite provided in the problem. suite provided in the problem.
...@@ -121,6 +125,7 @@ def check_correctness(problem: Dict, completion: str, timeout: float, ...@@ -121,6 +125,7 @@ def check_correctness(problem: Dict, completion: str, timeout: float,
# These system calls are needed when cleaning up tempdir. # These system calls are needed when cleaning up tempdir.
import os import os
import shutil import shutil
rmtree = shutil.rmtree rmtree = shutil.rmtree
rmdir = os.rmdir rmdir = os.rmdir
chdir = os.chdir chdir = os.chdir
...@@ -130,9 +135,12 @@ def check_correctness(problem: Dict, completion: str, timeout: float, ...@@ -130,9 +135,12 @@ def check_correctness(problem: Dict, completion: str, timeout: float,
# Construct the check program and run it. # Construct the check program and run it.
check_program = ( check_program = (
problem["prompt"] + completion + "\n" + problem["prompt"]
problem["test"] + "\n" + + completion
f"check({problem['entry_point']})" + "\n"
+ problem["test"]
+ "\n"
+ f"check({problem['entry_point']})"
) )
try: try:
...@@ -175,6 +183,7 @@ def check_correctness(problem: Dict, completion: str, timeout: float, ...@@ -175,6 +183,7 @@ def check_correctness(problem: Dict, completion: str, timeout: float,
def time_limit(seconds: float): def time_limit(seconds: float):
def signal_handler(signum, frame): def signal_handler(signum, frame):
raise TimeoutException("Timed out!") raise TimeoutException("Timed out!")
signal.setitimer(signal.ITIMER_REAL, seconds) signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler) signal.signal(signal.SIGALRM, signal_handler)
try: try:
...@@ -204,7 +213,7 @@ class TimeoutException(Exception): ...@@ -204,7 +213,7 @@ class TimeoutException(Exception):
class WriteOnlyStringIO(io.StringIO): class WriteOnlyStringIO(io.StringIO):
""" StringIO that throws an exception when it's read from """ """StringIO that throws an exception when it's read from"""
def read(self, *args, **kwargs): def read(self, *args, **kwargs):
raise IOError raise IOError
...@@ -216,12 +225,12 @@ class WriteOnlyStringIO(io.StringIO): ...@@ -216,12 +225,12 @@ class WriteOnlyStringIO(io.StringIO):
raise IOError raise IOError
def readable(self, *args, **kwargs): def readable(self, *args, **kwargs):
""" Returns True if the IO object can be read. """ """Returns True if the IO object can be read."""
return False return False
class redirect_stdin(contextlib._RedirectStream): # type: ignore class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = 'stdin' _stream = "stdin"
@contextlib.contextmanager @contextlib.contextmanager
...@@ -238,13 +247,14 @@ def chdir(root): ...@@ -238,13 +247,14 @@ def chdir(root):
finally: finally:
os.chdir(cwd) os.chdir(cwd)
def stream_jsonl(filename: str) -> Iterable[Dict]: def stream_jsonl(filename: str) -> Iterable[Dict]:
""" """
Parses each jsonl line and yields it as a dictionary Parses each jsonl line and yields it as a dictionary
""" """
if filename.endswith(".gz"): if filename.endswith(".gz"):
with open(filename, "rb") as gzfp: with open(filename, "rb") as gzfp:
with gzip.open(gzfp, 'rt') as fp: with gzip.open(gzfp, "rt") as fp:
for line in fp: for line in fp:
if any(not x.isspace() for x in line): if any(not x.isspace() for x in line):
yield json.loads(line) yield json.loads(line)
...@@ -254,6 +264,7 @@ def stream_jsonl(filename: str) -> Iterable[Dict]: ...@@ -254,6 +264,7 @@ def stream_jsonl(filename: str) -> Iterable[Dict]:
if any(not x.isspace() for x in line): if any(not x.isspace() for x in line):
yield json.loads(line) yield json.loads(line)
def estimate_pass_at_k( def estimate_pass_at_k(
num_samples: Union[int, List[int], np.ndarray], num_samples: Union[int, List[int], np.ndarray],
num_correct: Union[List[int], np.ndarray], num_correct: Union[List[int], np.ndarray],
...@@ -288,7 +299,10 @@ def evaluate_functional_correctness( ...@@ -288,7 +299,10 @@ def evaluate_functional_correctness(
n_workers: int = 4, n_workers: int = 4,
timeout: float = 3.0, timeout: float = 3.0,
): ):
problems = {example["task_id"]: example for example in load_dataset("openai_humaneval")["test"]} problems = {
example["task_id"]: example
for example in load_dataset("openai_humaneval")["test"]
}
# Check the generated samples against test suites. # Check the generated samples against test suites.
with ThreadPoolExecutor(max_workers=n_workers) as executor: with ThreadPoolExecutor(max_workers=n_workers) as executor:
...@@ -308,9 +322,11 @@ def evaluate_functional_correctness( ...@@ -308,9 +322,11 @@ def evaluate_functional_correctness(
n_samples += 1 n_samples += 1
if len(completion_id) < len(problems): if len(completion_id) < len(problems):
include_keys = list(problems.keys())[:len(completion_id)] include_keys = list(problems.keys())[: len(completion_id)]
print(f"Only found {len(completion_id)} solutions, reducing problems from {len(problems)}...") print(
problems = {k:v for k,v in problems.items() if k in include_keys} f"Only found {len(completion_id)} solutions, reducing problems from {len(problems)}..."
)
problems = {k: v for k, v in problems.items() if k in include_keys}
assert len(completion_id) == len(problems), "Some problems are not attempted." assert len(completion_id) == len(problems), "Some problems are not attempted."
...@@ -347,6 +363,7 @@ def evaluate_functional_correctness( ...@@ -347,6 +363,7 @@ def evaluate_functional_correctness(
return pass_at_k return pass_at_k
def reliability_guard(maximum_memory_bytes: Optional[int] = None): def reliability_guard(maximum_memory_bytes: Optional[int] = None):
""" """
This disables various destructive functions and prevents the generated code This disables various destructive functions and prevents the generated code
...@@ -364,19 +381,28 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): ...@@ -364,19 +381,28 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
if maximum_memory_bytes is not None: if maximum_memory_bytes is not None:
import resource import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) resource.setrlimit(
if not platform.uname().system == 'Darwin': resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) )
resource.setrlimit(
resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
)
if not platform.uname().system == "Darwin":
resource.setrlimit(
resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
)
faulthandler.disable() faulthandler.disable()
import builtins import builtins
builtins.exit = None builtins.exit = None
builtins.quit = None builtins.quit = None
import os import os
os.environ['OMP_NUM_THREADS'] = '1'
os.environ["OMP_NUM_THREADS"] = "1"
os.kill = None os.kill = None
os.system = None os.system = None
...@@ -407,25 +433,32 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): ...@@ -407,25 +433,32 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
os.chdir = None os.chdir = None
import shutil import shutil
shutil.rmtree = None shutil.rmtree = None
shutil.move = None shutil.move = None
shutil.chown = None shutil.chown = None
import subprocess import subprocess
subprocess.Popen = None # type: ignore subprocess.Popen = None # type: ignore
import sys import sys
sys.modules['ipdb'] = None
sys.modules['joblib'] = None
sys.modules['resource'] = None
sys.modules['psutil'] = None
sys.modules['tkinter'] = None
if __name__ == '__main__': sys.modules["ipdb"] = None
sys.modules["joblib"] = None
sys.modules["resource"] = None
sys.modules["psutil"] = None
sys.modules["tkinter"] = None
if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "true" os.environ["TOKENIZERS_PARALLELISM"] = "true"
from awq import AutoAWQForCausalLM from awq import AutoAWQForCausalLM
model_path = 'TheBloke/zephyr-7B-beta-AWQ'
model = AutoAWQForCausalLM.from_quantized(model_path, device_map="auto", max_new_tokens=2048) model_path = "TheBloke/zephyr-7B-beta-AWQ"
model = AutoAWQForCausalLM.from_quantized(
model_path, device_map="auto", max_seq_len=2048
)
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path)
eval_humaneval(model, tokenizer) eval_humaneval(model, tokenizer)
...@@ -15,17 +15,20 @@ try: ...@@ -15,17 +15,20 @@ try:
from scipy.stats import bayes_mvs from scipy.stats import bayes_mvs
from scipy.stats import t as student_t from scipy.stats import t as student_t
from scipy.stats.mstats import mquantiles_cimj from scipy.stats.mstats import mquantiles_cimj
SCIPY_INSTALLED = True SCIPY_INSTALLED = True
except: except:
SCIPY_INSTALLED = False SCIPY_INSTALLED = False
@torch.jit.script @torch.jit.script
def rel_entr(x, y): def rel_entr(x, y):
mask = (x > 0) & (y > 0) mask = (x > 0) & (y > 0)
result = torch.where(mask, x * torch.log(x / y), torch.zeros_like(x)) result = torch.where(mask, x * torch.log(x / y), torch.zeros_like(x))
result[(x > 0) & (y <= 0)] = float('inf') result[(x > 0) & (y <= 0)] = float("inf")
return result return result
def bin_conf(p, n, z): def bin_conf(p, n, z):
# Binomial distribution confidence bounds # Binomial distribution confidence bounds
# Bayes estimator when p is degenerate # Bayes estimator when p is degenerate
...@@ -33,15 +36,23 @@ def bin_conf(p, n, z): ...@@ -33,15 +36,23 @@ def bin_conf(p, n, z):
p = 1 / (n + 2) p = 1 / (n + 2)
if p == 1: if p == 1:
p = 1 - 1 / (n + 2) p = 1 - 1 / (n + 2)
return z * torch.sqrt(p*(1-p)/n) return z * torch.sqrt(p * (1 - p) / n)
def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, tokenizer: PreTrainedTokenizer, seqlen: int):
def eval_kl_divergence(
ref_model: PreTrainedModel,
eval_model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
seqlen: int,
):
if not SCIPY_INSTALLED: if not SCIPY_INSTALLED:
raise Exception("SciPy needs to be installed for KL Divergence evaluation: pip install scipy") raise Exception(
"SciPy needs to be installed for KL Divergence evaluation: pip install scipy"
)
# load dataset # load dataset
data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
data = tokenizer("\n\n".join(data['text']), return_tensors='pt') data = tokenizer("\n\n".join(data["text"]), return_tensors="pt")
data = data.input_ids.to(ref_model.device) data = data.input_ids.to(ref_model.device)
n_samples = data.numel() // seqlen n_samples = data.numel() // seqlen
...@@ -59,9 +70,9 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, ...@@ -59,9 +70,9 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel,
# start eval # start eval
with tqdm(range(n_samples), desc="KL Div") as progress_bar: with tqdm(range(n_samples), desc="KL Div") as progress_bar:
for i in progress_bar: for i in progress_bar:
start_index = (i * seqlen) start_index = i * seqlen
end_index = ((i + 1) * seqlen) end_index = (i + 1) * seqlen
batch_len = end_index-start_index batch_len = end_index - start_index
batch = data[:, start_index:end_index] batch = data[:, start_index:end_index]
# get logits # get logits
...@@ -97,9 +108,9 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, ...@@ -97,9 +108,9 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel,
f"Top 10: {top10 / samples:.4g}" f"Top 10: {top10 / samples:.4g}"
) )
z = student_t.ppf(1 - alpha/2, samples) z = student_t.ppf(1 - alpha / 2, samples)
m_conf = z*np.sqrt(np.mean([k**2 for k in kls])/len(kls)) m_conf = z * np.sqrt(np.mean([k**2 for k in kls]) / len(kls))
m, _, __ = bayes_mvs(kls, 1-alpha) m, _, __ = bayes_mvs(kls, 1 - alpha)
q90 = np.quantile(kls, 0.90) q90 = np.quantile(kls, 0.90)
q95 = np.quantile(kls, 0.95) q95 = np.quantile(kls, 0.95)
q99 = np.quantile(kls, 0.99) q99 = np.quantile(kls, 0.99)
...@@ -116,20 +127,33 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel, ...@@ -116,20 +127,33 @@ def eval_kl_divergence(ref_model: PreTrainedModel, eval_model: PreTrainedModel,
print(f"max: {np.max(kls):.4g}") print(f"max: {np.max(kls):.4g}")
print(" -- ") print(" -- ")
print("Reference top token in eval top-n probability:") print("Reference top token in eval top-n probability:")
print(f" ** ref_top1: {top1 / samples:.4g} ± {bin_conf(top1/samples, samples, z):.4g}") print(
print(f" ** ref_top5: {top5 / samples:.4g} ± {bin_conf(top5/samples, samples, z):.4g}") f" ** ref_top1: {top1 / samples:.4g} ± {bin_conf(top1/samples, samples, z):.4g}"
print(f" ** ref_top10: {top10 / samples:4g} ± {bin_conf(top10/samples, samples, z):.4g}") )
print(
f" ** ref_top5: {top5 / samples:.4g} ± {bin_conf(top5/samples, samples, z):.4g}"
)
print(
f" ** ref_top10: {top10 / samples:4g} ± {bin_conf(top10/samples, samples, z):.4g}"
)
print("Eval top token in reference top-n probability:") print("Eval top token in reference top-n probability:")
print(f" ** eval_top5: {eval_top5 / samples:.4g} ± {bin_conf(eval_top5/samples, samples, z):.4g}") print(
print(f" ** eval_top10: {eval_top10 / samples:4g} ± {bin_conf(eval_top10/samples, samples, z):.4g}") f" ** eval_top5: {eval_top5 / samples:.4g} ± {bin_conf(eval_top5/samples, samples, z):.4g}"
)
print(
f" ** eval_top10: {eval_top10 / samples:4g} ± {bin_conf(eval_top10/samples, samples, z):.4g}"
)
if __name__ == '__main__': if __name__ == "__main__":
# ref_model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" # ref_model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
# eval_model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T" # eval_model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"
ref_model_path = eval_model_path = "gpt2" ref_model_path = eval_model_path = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(ref_model_path) tokenizer = AutoTokenizer.from_pretrained(ref_model_path)
ref_model = AutoModelForCausalLM.from_pretrained(ref_model_path, device_map="auto") ref_model = AutoModelForCausalLM.from_pretrained(ref_model_path, device_map="auto")
eval_model = AutoModelForCausalLM.from_pretrained(eval_model_path, device_map="auto") eval_model = AutoModelForCausalLM.from_pretrained(
eval_model_path, device_map="auto"
)
eval_kl_divergence(ref_model, eval_model, tokenizer, seqlen=1024) eval_kl_divergence(ref_model, eval_model, tokenizer, seqlen=1024)
import os import os
import json import json
import logging
from typing import Dict, Optional, List from typing import Dict, Optional, List
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field
from transformers.utils.hub import PushToHubMixin, cached_file from transformers.utils.hub import PushToHubMixin, cached_file
@dataclass @dataclass
class AwqConfig(PushToHubMixin): class AwqConfig(PushToHubMixin):
quant_method: str = field(default="awq") quant_method: str = field(default="awq")
zero_point: bool = field(default=True) zero_point: bool = field(default=True)
q_group_size: int = field(default=128) q_group_size: int = field(default=128)
w_bit: int = field(default=4) w_bit: int = field(default=4)
version: str = field(default="GEMM") version: str = field(default="gemm")
config_file_name = "quant_config.json" config_file_name = "config.json"
modules_to_not_convert: Optional[List] = None modules_to_not_convert: Optional[List] = None
def save_pretrained(self, save_dir: str, **kwargs):
logging.warning(
"`quant_config.json` is being deprecated in the future"
" in favor of quantization_config in config.json."
)
with open(os.path.join(save_dir, self.config_file_name), "w+", encoding="utf-8") as file:
file.write(json.dumps(self.to_dict(), indent=4))
@classmethod @classmethod
def from_dict(cls, quant_config: Dict={}): def from_dict(cls, quant_config: Dict = {}):
if not quant_config: if not quant_config:
quant_config = cls() quant_config = cls()
else: else:
quant_config = cls(**quant_config) quant_config = cls(**quant_config)
quant_config.version = quant_config.version.lower()
return quant_config return quant_config
...@@ -63,11 +56,18 @@ class AwqConfig(PushToHubMixin): ...@@ -63,11 +56,18 @@ class AwqConfig(PushToHubMixin):
_commit_hash=commit_hash, _commit_hash=commit_hash,
) )
quant_config = None
if os.path.exists(resolved_config_file): if os.path.exists(resolved_config_file):
with open(resolved_config_file, 'r', encoding="utf-8") as file: with open(resolved_config_file, "r", encoding="utf-8") as file:
loaded_config = json.loads(file.read()) loaded_config = json.loads(file.read())
quant_config = cls(**loaded_config)
else: quant_config = loaded_config.get("quantization_config")
if quant_config is not None:
awq_config = cls.from_transformers_dict(cls, quant_config)
quant_config = cls(**awq_config)
if quant_config is None:
quant_config = cls() quant_config = cls()
return quant_config return quant_config
...@@ -90,3 +90,13 @@ class AwqConfig(PushToHubMixin): ...@@ -90,3 +90,13 @@ class AwqConfig(PushToHubMixin):
"version": self.version.lower(), "version": self.version.lower(),
"modules_to_not_convert": self.modules_to_not_convert, "modules_to_not_convert": self.modules_to_not_convert,
} }
def from_transformers_dict(self, transformers_dict: Dict):
return {
"quant_method": transformers_dict.get("quant_method"),
"zero_point": transformers_dict.get("zero_point"),
"q_group_size": transformers_dict.get("group_size"),
"w_bit": transformers_dict.get("bits"),
"version": transformers_dict.get("version"),
"modules_to_not_convert": transformers_dict.get("modules_to_not_convert"),
}
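To make the new loading path concrete: the config class now reads the nested `quantization_config` block from `config.json` and renames the transformers-style keys (`group_size`, `bits`) to AutoAWQ's (`q_group_size`, `w_bit`). Below is a self-contained sketch of that mapping; the sample values are placeholders, not taken from a real checkpoint:

```python
import json

# Illustrative quantization_config block as it might appear inside config.json
# (values are placeholders).
config_json = {
    "model_type": "llama",
    "quantization_config": {
        "quant_method": "awq",
        "zero_point": True,
        "group_size": 128,   # AwqConfig calls this q_group_size
        "bits": 4,           # AwqConfig calls this w_bit
        "version": "gemm",
        "modules_to_not_convert": None,
    },
}

# The same renaming that from_transformers_dict performs above.
tf = config_json["quantization_config"]
awq_kwargs = {
    "quant_method": tf.get("quant_method"),
    "zero_point": tf.get("zero_point"),
    "q_group_size": tf.get("group_size"),
    "w_bit": tf.get("bits"),
    "version": tf.get("version"),
    "modules_to_not_convert": tf.get("modules_to_not_convert"),
}
print(json.dumps(awq_kwargs, indent=2))
```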
...@@ -6,13 +6,14 @@ from awq.modules.fused.block import LlamaLikeBlock ...@@ -6,13 +6,14 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldAquilaDecoderLayer, LlamaDecoderLayer as OldAquilaDecoderLayer,
LlamaForCausalLM as OldAquilaForCausalLM LlamaForCausalLM as OldAquilaForCausalLM,
) )
from awq.modules.fused.norm import FasterTransformerRMSNorm from awq.modules.fused.norm import FasterTransformerRMSNorm
class AquilaAWQForCausalLM(BaseAWQForCausalLM): class AquilaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "AquilaDecoderLayer" layer_type = "AquilaDecoderLayer"
max_new_tokens_key = "max_position_embeddings" max_seq_len_key = "max_position_embeddings"
@staticmethod @staticmethod
def fuse_layers(model: OldAquilaForCausalLM): def fuse_layers(model: OldAquilaForCausalLM):
...@@ -25,50 +26,62 @@ class AquilaAWQForCausalLM(BaseAWQForCausalLM): ...@@ -25,50 +26,62 @@ class AquilaAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod @staticmethod
def get_act_for_scaling(module: OldAquilaDecoderLayer): def get_act_for_scaling(module: OldAquilaDecoderLayer):
return dict( return dict(is_scalable=False)
is_scalable=False
)
@staticmethod @staticmethod
def move_embed(model: OldAquilaForCausalLM, device: str): def move_embed(model: OldAquilaForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device) model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod @staticmethod
def get_layers_for_scaling(module: OldAquilaDecoderLayer, input_feat, module_kwargs): def get_layers_for_scaling(
module: OldAquilaDecoderLayer, input_feat, module_kwargs
):
layers = [] layers = []
# attention input # attention input
layers.append(dict( layers.append(
dict(
prev_op=module.input_layernorm, prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj, layers=[
module.self_attn.k_proj, module.self_attn.v_proj], module.self_attn.q_proj,
inp=input_feat['self_attn.q_proj'], module.self_attn.k_proj,
module2inspect=module.self_attn, kwargs=module_kwargs, module.self_attn.v_proj,
)) ],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out # attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict( layers.append(
dict(
prev_op=module.self_attn.v_proj, prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj], layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'], inp=input_feat["self_attn.o_proj"],
)) )
)
# linear 1 # linear 1
layers.append(dict( layers.append(
dict(
prev_op=module.post_attention_layernorm, prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj], layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'], inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp, module2inspect=module.mlp,
)) )
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.mlp.up_proj, prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj], layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'], inp=input_feat["mlp.down_proj"],
)) )
)
return layers return layers
...@@ -78,8 +91,9 @@ class AquilaFuser: ...@@ -78,8 +91,9 @@ class AquilaFuser:
self.model = model self.model = model
self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [ self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules() (name, module)
if 'AquilaDecoderLayer'.lower() in module.__class__.__name__.lower() for name, module in self.model.named_modules()
if "AquilaDecoderLayer".lower() in module.__class__.__name__.lower()
] ]
def fuse_transformer(self): def fuse_transformer(self):
...@@ -92,17 +106,17 @@ class AquilaFuser: ...@@ -92,17 +106,17 @@ class AquilaFuser:
module, module,
module.self_attn.q_proj, module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.k_proj,
module.self_attn.v_proj module.self_attn.v_proj,
) )
norm_1 = FasterTransformerRMSNorm( norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.weight, module.input_layernorm.variance_epsilon
module.input_layernorm.variance_epsilon
) )
norm_2 = FasterTransformerRMSNorm( norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight, module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon module.post_attention_layernorm.variance_epsilon,
) )
blocks.append(LlamaLikeBlock( blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size, hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads, n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads, n_kv_heads=self.model.config.num_key_value_heads,
...@@ -112,8 +126,9 @@ class AquilaFuser: ...@@ -112,8 +126,9 @@ class AquilaFuser:
norm_1=norm_1, norm_1=norm_1,
norm_2=norm_2, norm_2=norm_2,
dev=device, dev=device,
max_seq_len=self.model.config.max_new_tokens max_seq_len=self.model.config.max_seq_len,
)) )
)
self.model.model = LlamaLikeModel( self.model.model = LlamaLikeModel(
self.model.config.vocab_size, self.model.config.vocab_size,
......
import os import os
import logging
from transformers import AutoConfig from transformers import AutoConfig
from awq.models import * from awq.models import *
from awq.models.base import BaseAWQForCausalLM from awq.models.base import BaseAWQForCausalLM
...@@ -21,7 +22,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = { ...@@ -21,7 +22,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"qwen": QwenAWQForCausalLM, "qwen": QwenAWQForCausalLM,
"baichuan": BaichuanAWQForCausalLM, "baichuan": BaichuanAWQForCausalLM,
"llava": LlavaAWQForCausalLM, "llava": LlavaAWQForCausalLM,
"qwen2": Qwen2AWQForCausalLM "qwen2": Qwen2AWQForCausalLM,
} }
...@@ -47,7 +48,7 @@ class AutoAWQForCausalLM: ...@@ -47,7 +48,7 @@ class AutoAWQForCausalLM:
self, self,
model_path, model_path,
trust_remote_code=True, trust_remote_code=True,
safetensors=False, safetensors=True,
device_map=None, device_map=None,
**model_init_kwargs, **model_init_kwargs,
) -> BaseAWQForCausalLM: ) -> BaseAWQForCausalLM:
...@@ -69,7 +70,7 @@ class AutoAWQForCausalLM: ...@@ -69,7 +70,7 @@ class AutoAWQForCausalLM:
self, self,
quant_path, quant_path,
quant_filename="", quant_filename="",
max_new_tokens=None, max_seq_len=2048,
trust_remote_code=True, trust_remote_code=True,
fuse_layers=True, fuse_layers=True,
use_exllama=False, use_exllama=False,
...@@ -83,11 +84,18 @@ class AutoAWQForCausalLM: ...@@ -83,11 +84,18 @@ class AutoAWQForCausalLM:
os.environ["AWQ_BATCH_SIZE"] = str(batch_size) os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
model_type = check_and_get_model_type(quant_path, trust_remote_code) model_type = check_and_get_model_type(quant_path, trust_remote_code)
if config_kwargs.get("max_new_tokens") is not None:
max_seq_len = config_kwargs["max_new_tokens"]
logging.warning(
"max_new_tokens argument is deprecated... gracefully "
"setting max_seq_len=max_new_tokens."
)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized( return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
quant_path, quant_path,
model_type, model_type,
quant_filename, quant_filename,
max_new_tokens, max_seq_len,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
fuse_layers=fuse_layers, fuse_layers=fuse_layers,
use_exllama=use_exllama, use_exllama=use_exllama,
......
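The hunk above renames the `max_new_tokens` loading argument to `max_seq_len` (defaulting to 2048) and adds a shim that remaps the old keyword with a deprecation warning. Assuming an AWQ checkpoint path (the one below is a placeholder), the call would look roughly like this; only the `max_seq_len` form is the supported style going forward:

```python
from awq import AutoAWQForCausalLM

quant_path = "TheBloke/zephyr-7B-beta-AWQ"  # placeholder AWQ checkpoint

# New API: the preallocated cache length is configured with max_seq_len.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, max_seq_len=2048)

# Legacy calls such as from_quantized(quant_path, max_new_tokens=2048) are still
# accepted for now: the shim above copies max_new_tokens into max_seq_len and
# logs a deprecation warning.
```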
...@@ -8,9 +8,10 @@ from transformers.models.llama.modeling_llama import ( ...@@ -8,9 +8,10 @@ from transformers.models.llama.modeling_llama import (
) )
from awq.modules.fused.norm import FasterTransformerRMSNorm from awq.modules.fused.norm import FasterTransformerRMSNorm
class BaichuanAWQForCausalLM(BaseAWQForCausalLM): class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "BaichuanLayer" layer_type = "BaichuanLayer"
max_new_tokens_key = "model_max_length" max_seq_len_key = "model_max_length"
@staticmethod @staticmethod
def fuse_layers(model): def fuse_layers(model):
...@@ -23,9 +24,7 @@ class BaichuanAWQForCausalLM(BaseAWQForCausalLM): ...@@ -23,9 +24,7 @@ class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod @staticmethod
def get_act_for_scaling(module): def get_act_for_scaling(module):
return dict( return dict(is_scalable=False)
is_scalable=False
)
@staticmethod @staticmethod
def move_embed(model, device: str): def move_embed(model, device: str):
...@@ -37,12 +36,15 @@ class BaichuanAWQForCausalLM(BaseAWQForCausalLM): ...@@ -37,12 +36,15 @@ class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
layers = [] layers = []
# attention input # attention input
layers.append(dict( layers.append(
dict(
prev_op=module.input_layernorm, prev_op=module.input_layernorm,
layers=[module.self_attn.W_pack], layers=[module.self_attn.W_pack],
inp=input_feat['self_attn.W_pack'], inp=input_feat["self_attn.W_pack"],
module2inspect=module.self_attn, kwargs=module_kwargs, module2inspect=module.self_attn,
)) kwargs=module_kwargs,
)
)
# # attention out # # attention out
# # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 # # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
...@@ -55,26 +57,32 @@ class BaichuanAWQForCausalLM(BaseAWQForCausalLM): ...@@ -55,26 +57,32 @@ class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
# attention out # attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
layers.append(dict( layers.append(
dict(
prev_op=module.self_attn.W_pack, prev_op=module.self_attn.W_pack,
layers=[module.self_attn.o_proj], layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'], inp=input_feat["self_attn.o_proj"],
)) )
)
# linear 1 # linear 1
layers.append(dict( layers.append(
dict(
prev_op=module.post_attention_layernorm, prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj], layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'], inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp, module2inspect=module.mlp,
)) )
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.mlp.up_proj, prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj], layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'], inp=input_feat["mlp.down_proj"],
)) )
)
return layers return layers
...@@ -84,8 +92,9 @@ class BaichuanFuser: ...@@ -84,8 +92,9 @@ class BaichuanFuser:
self.model = model self.model = model
self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules() (name, module)
if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower() for name, module in self.model.named_modules()
if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower()
] ]
def fuse_transformer(self): def fuse_transformer(self):
...@@ -101,14 +110,14 @@ class BaichuanFuser: ...@@ -101,14 +110,14 @@ class BaichuanFuser:
# ) # )
qkv = module.self_attn.W_pack qkv = module.self_attn.W_pack
norm_1 = FasterTransformerRMSNorm( norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.weight, module.input_layernorm.epsilon
module.input_layernorm.epsilon
) )
norm_2 = FasterTransformerRMSNorm( norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight, module.post_attention_layernorm.weight,
module.post_attention_layernorm.epsilon module.post_attention_layernorm.epsilon,
) )
blocks.append(LlamaLikeBlock( blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size, hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads, n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_attention_heads, n_kv_heads=self.model.config.num_attention_heads,
...@@ -118,9 +127,10 @@ class BaichuanFuser: ...@@ -118,9 +127,10 @@ class BaichuanFuser:
norm_1=norm_1, norm_1=norm_1,
norm_2=norm_2, norm_2=norm_2,
dev=device, dev=device,
max_seq_len=self.model.config.max_new_tokens, max_seq_len=self.model.config.max_seq_len,
use_alibi=True use_alibi=True,
)) )
)
self.model.model = LlamaLikeModel( self.model.model = LlamaLikeModel(
self.model.config.vocab_size, self.model.config.vocab_size,
......
...@@ -6,8 +6,9 @@ import transformers ...@@ -6,8 +6,9 @@ import transformers
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
from typing import List, Union from typing import List, Union, Dict
from safetensors.torch import save_file from safetensors.torch import save_file
from typing_extensions import Doc, Annotated
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint from transformers.modeling_utils import shard_checkpoint
...@@ -27,6 +28,7 @@ from transformers import ( ...@@ -27,6 +28,7 @@ from transformers import (
PretrainedConfig, PretrainedConfig,
AutoProcessor, AutoProcessor,
CLIPImageProcessor, CLIPImageProcessor,
PreTrainedTokenizer,
) )
from accelerate.big_modeling import ( from accelerate.big_modeling import (
init_empty_weights, init_empty_weights,
...@@ -64,8 +66,21 @@ TRANSFORMERS_AUTO_MAPPING_DICT = { ...@@ -64,8 +66,21 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
class BaseAWQForCausalLM(nn.Module): class BaseAWQForCausalLM(nn.Module):
def __init__( def __init__(
self, model, model_type, is_quantized, config, quant_config, processor self,
model: Annotated[PreTrainedModel, Doc("The pretrained or quantized model.")],
model_type: Annotated[str, Doc("The model type, found in config.json.")],
is_quantized: Annotated[
bool, Doc("Indicates if the current model is quantized.")
],
config: Annotated[PretrainedConfig, Doc("The config of the model.")],
quant_config: Annotated[
AwqConfig, Doc("The quantization config of the model.")
],
processor: Annotated[
AutoProcessor, Doc("An optional processor, e.g. for vision models.")
],
): ):
"""The base model for all AutoAWQ models."""
super().__init__() super().__init__()
self.model: PreTrainedModel = model self.model: PreTrainedModel = model
self.model_type: str = model_type self.model_type: str = model_type
...@@ -75,30 +90,68 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -75,30 +90,68 @@ class BaseAWQForCausalLM(nn.Module):
self.quant_config: AwqConfig = quant_config self.quant_config: AwqConfig = quant_config
self.processor: CLIPImageProcessor = processor self.processor: CLIPImageProcessor = processor
def to(self, device: str): def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
"""A utility function for moving the model to a device."""
return self.model.to(device) return self.model.to(device)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
"""A forward function that mimics the torch forward."""
return self.model(*args, **kwargs) return self.model(*args, **kwargs)
def generate(self, *args, **kwargs): def generate(self, *args, **kwargs):
"""A generate function that mimics the HF generate function."""
with torch.inference_mode(): with torch.inference_mode():
return self.model.generate(*args, **kwargs) return self.model.generate(*args, **kwargs)
@torch.no_grad() @torch.no_grad()
def quantize( def quantize(
self, self,
tokenizer=None, tokenizer: Annotated[
quant_config={}, PreTrainedTokenizer, Doc("The tokenizer to use for quantization.")
calib_data: Union[str, List[str]] = "pileval", ] = None,
split="train", quant_config: Annotated[
text_column="text", Dict, Doc("The quantization config you want to use.")
duo_scaling=True, ] = {},
modules_to_not_convert=None, calib_data: Annotated[
export_compatible=False, Union[str, List[str]],
Doc(
"The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples."
),
] = "pileval",
split: Annotated[str, Doc("The split of calib_data.")] = "train",
text_column: Annotated[str, Doc("The text column of calib_data.")] = "text",
duo_scaling: Annotated[
bool, Doc("Whether to scale using both w/x or just x.")
] = True,
export_compatible: Annotated[
bool,
Doc(
"This argument avoids real quantization by only applying the scales without quantizing down to FP16."
),
] = False,
): ):
"""
The main quantization function that you can use to quantize your model.
Example:
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = "..."
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
model.quantize(tokenizer, quant_config)
```
"""
self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)
if hasattr(self, "modules_to_not_convert"):
self.quant_config.modules_to_not_convert = self.modules_to_not_convert
self.quantizer = AwqQuantizer( self.quantizer = AwqQuantizer(
self, self,
self.model, self.model,
...@@ -111,7 +164,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -111,7 +164,7 @@ class BaseAWQForCausalLM(nn.Module):
split, split,
text_column, text_column,
duo_scaling, duo_scaling,
modules_to_not_convert=modules_to_not_convert, modules_to_not_convert=self.quant_config.modules_to_not_convert,
export_compatible=export_compatible, export_compatible=export_compatible,
) )
self.quantizer.quantize() self.quantizer.quantize()
...@@ -124,6 +177,9 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -124,6 +177,9 @@ class BaseAWQForCausalLM(nn.Module):
A utility function for the following scenario. Note that save_quantized will A utility function for the following scenario. Note that save_quantized will
overwrite existing weights if you use the same quant_path. overwrite existing weights if you use the same quant_path.
Example:
```python
model.quantize( model.quantize(
tokenizer, tokenizer,
quant_config=quant_config, quant_config=quant_config,
...@@ -132,6 +188,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -132,6 +188,7 @@ class BaseAWQForCausalLM(nn.Module):
model.save_quantized(...) # produces GGUF/other compat weights model.save_quantized(...) # produces GGUF/other compat weights
model.pack(...) # makes the model CUDA compat model.pack(...) # makes the model CUDA compat
model.save_quantized(...) # produces CUDA compat weights model.save_quantized(...) # produces CUDA compat weights
```
""" """
self.quantizer.pack() self.quantizer.pack()
...@@ -139,7 +196,16 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -139,7 +196,16 @@ class BaseAWQForCausalLM(nn.Module):
def fuse_layers(model): def fuse_layers(model):
pass pass
def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"): def save_quantized(
self,
save_dir: Annotated[str, Doc("The directory to save your model to.")],
safetensors: Annotated[
bool, Doc("Whether to save the model as safetensors or torch files.")
] = True,
shard_size: Annotated[
str, Doc("The shard size for sharding large models into multiple chunks.")
] = "5GB",
):
save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir
# Save model # Save model
...@@ -154,7 +220,6 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -154,7 +220,6 @@ class BaseAWQForCausalLM(nn.Module):
self.model.config.quantization_config = self.quant_config.to_transformers_dict() self.model.config.quantization_config = self.quant_config.to_transformers_dict()
self.model.generation_config.do_sample = True self.model.generation_config.do_sample = True
self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict()) self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
self.quant_config.save_pretrained(save_dir)
# Vision transformers have a processor # Vision transformers have a processor
if self.processor is not None: if self.processor is not None:
...@@ -195,14 +260,37 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -195,14 +260,37 @@ class BaseAWQForCausalLM(nn.Module):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
self, self,
model_path, model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
model_type, model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
torch_dtype: torch.dtype = torch.float16, torch_dtype: Annotated[
trust_remote_code=True, torch.dtype,
safetensors=False, Doc(
device_map=None, "The dtype to load the model as. May not work with other values than float16."
**model_init_kwargs, ),
] = torch.float16,
trust_remote_code: Annotated[
bool,
Doc(
"Useful for Huggingface repositories that have not been integrated into transformers yet."
),
] = True,
safetensors: Annotated[
bool, Doc("Whether to download/load safetensors instead of torch weights.")
] = True,
device_map: Annotated[
Union[str, Dict],
Doc(
"A device map that will be passed onto the model loading method from transformers."
),
] = None,
**model_init_kwargs: Annotated[
Dict,
Doc(
"Additional kwargs that are passed to the model during initialization."
),
],
): ):
"""A method for initialization of pretrained models, usually in FP16."""
# Get weights path and quant config # Get weights path and quant config
model_weights_path, config, quant_config = self._load_config( model_weights_path, config, quant_config = self._load_config(
self, model_path, "", safetensors, trust_remote_code=trust_remote_code self, model_path, "", safetensors, trust_remote_code=trust_remote_code
...@@ -240,31 +328,70 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -240,31 +328,70 @@ class BaseAWQForCausalLM(nn.Module):
@classmethod @classmethod
def from_quantized( def from_quantized(
self, self,
model_path, model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
model_type, model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
model_filename="", model_filename: Annotated[
max_new_tokens=None, str, Doc("Load a specific model's filename by specifying this argument.")
torch_dtype=torch.float16, ] = "",
trust_remote_code=True, max_seq_len: Annotated[
safetensors=True, int,
is_quantized=True, Doc(
fuse_layers=False, "The maximum cached sequence length of the model. Larger values may increase loading time and memory usage."
use_exllama=False, ),
use_exllama_v2=False, ] = None,
version="GEMM", torch_dtype: Annotated[
device_map="balanced", torch.dtype,
offload_folder=None, Doc(
**config_kwargs, "The dtype to load the model as. May not work with other values than float16."
),
] = torch.float16,
trust_remote_code: Annotated[
bool,
Doc(
"Useful for Huggingface repositories that have not been integrated into transformers yet."
),
] = True,
safetensors: Annotated[
bool, Doc("Whether to download/load safetensors instead of torch weights.")
] = True,
fuse_layers: Annotated[
bool,
Doc(
"Whether to use fused/optimized combination of layers for increased speed."
),
] = True,
use_exllama: Annotated[
bool, Doc("Whether to map the weights to ExLlamaV1 kernels.")
] = False,
use_exllama_v2: Annotated[
bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
] = False,
device_map: Annotated[
Union[str, Dict],
Doc(
"A device map that will be passed onto the model loading method from transformers."
),
] = "balanced",
offload_folder: Annotated[
str,
Doc("The folder to offload the model to."),
] = None,
**config_kwargs: Annotated[
Dict,
Doc(
"Additional kwargs that are passed to the config during initialization."
),
],
): ):
"""A method for initialization of a quantized model, usually in INT4."""
# [STEP 1-2] Load weights path and configs # [STEP 1-2] Load weights path and configs
model_weights_path, config, quant_config = self._load_config( model_weights_path, config, quant_config = self._load_config(
self, self,
model_path, model_path,
model_filename, model_filename,
safetensors, safetensors,
version,
trust_remote_code, trust_remote_code,
max_new_tokens=max_new_tokens, max_seq_len=max_seq_len,
**config_kwargs, **config_kwargs,
) )
...@@ -306,7 +433,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -306,7 +433,7 @@ class BaseAWQForCausalLM(nn.Module):
if fuse_layers: if fuse_layers:
self.fuse_layers(model) self.fuse_layers(model)
if quant_config.version == "Marlin": if quant_config.version == "marlin":
model = marlin_post_init(model) model = marlin_post_init(model)
elif use_exllama: elif use_exllama:
...@@ -316,14 +443,14 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -316,14 +443,14 @@ class BaseAWQForCausalLM(nn.Module):
# creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
model = exllamav2_post_init( model = exllamav2_post_init(
model, model,
max_input_len=max_new_tokens or 2048, max_input_len=max_seq_len or 2048,
max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)), max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
) )
return self( return self(
model, model,
model_type, model_type,
is_quantized=is_quantized, is_quantized=True,
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
processor=None, processor=None,
...@@ -334,9 +461,8 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -334,9 +461,8 @@ class BaseAWQForCausalLM(nn.Module):
model_path, model_path,
model_filename, model_filename,
safetensors=True, safetensors=True,
version="GEMM",
trust_remote_code=True, trust_remote_code=True,
max_new_tokens=4096, max_seq_len=4096,
**config_kwargs, **config_kwargs,
): ):
# [STEP 1] Download model if path is not a directory # [STEP 1] Download model if path is not a directory
...@@ -359,22 +485,22 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -359,22 +485,22 @@ class BaseAWQForCausalLM(nn.Module):
quant_config = AwqConfig.from_pretrained(model_path) quant_config = AwqConfig.from_pretrained(model_path)
# Load model config and set max generation length # Load model config and set max generation length
if max_new_tokens is None and hasattr(self, "max_new_tokens_key"): if max_seq_len is None and hasattr(self, "max_seq_len_key"):
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code, **config_kwargs model_path, trust_remote_code=trust_remote_code, **config_kwargs
) )
config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048) config.max_seq_len = getattr(config, self.max_seq_len_key, 2048)
# To add the generate support for Multi-modal models as well # To add the generate support for Multi-modal models as well
if hasattr(config, "text_config"): if hasattr(config, "text_config"):
config.text_config.max_new_tokens = getattr( config.text_config.max_seq_len = getattr(
config, self.max_new_tokens_key, 2048 config, self.max_seq_len_key, 2048
) )
else: else:
max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens max_seq_len = 2048 if max_seq_len is None else max_seq_len
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code, **config_kwargs model_path, trust_remote_code=trust_remote_code, **config_kwargs
) )
config.max_new_tokens = max_new_tokens config.max_seq_len = max_seq_len
return model_weights_path, config, quant_config return model_weights_path, config, quant_config
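The `_load_config` change above replaces `max_new_tokens` with `max_seq_len`: an explicit argument wins, otherwise the value is read from the per-model `max_seq_len_key` attribute of the Hugging Face config, otherwise it falls back to 2048. A minimal standalone sketch of that resolution order (the function name and `default` parameter are illustrative, not part of AutoAWQ):

```python
# Illustrative only: mirrors the fallback order implemented in _load_config above.
# `config` stands in for a transformers PretrainedConfig instance and
# `max_seq_len_key` for the per-model attribute name, e.g. "max_position_embeddings".
def resolve_max_seq_len(config, max_seq_len=None, max_seq_len_key=None, default=2048):
    if max_seq_len is not None:
        return max_seq_len  # explicit user override wins
    if max_seq_len_key is not None and hasattr(config, max_seq_len_key):
        return getattr(config, max_seq_len_key)  # e.g. 4096 for a Llama-2 config
    return default
```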
...@@ -383,7 +509,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -383,7 +509,7 @@ class BaseAWQForCausalLM(nn.Module):
): ):
# Real quantization of weights # Real quantization of weights
assert not ( assert not (
version == "GEMV" and (use_exllama or use_exllama_v2) version == "gemv" and (use_exllama or use_exllama_v2)
), "Exllama kernels only support GEMM version." ), "Exllama kernels only support GEMM version."
# Get blocks of model # Get blocks of model
...@@ -405,15 +531,15 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -405,15 +531,15 @@ class BaseAWQForCausalLM(nn.Module):
# Replace nn.Linear with WQLinear # Replace nn.Linear with WQLinear
for name, module in named_linears.items(): for name, module in named_linears.items():
if version == "Marlin": if version == "marlin":
q_linear_module = WQLinear_Marlin q_linear_module = WQLinear_Marlin
elif use_exllama: elif use_exllama:
q_linear_module = WQLinear_Exllama q_linear_module = WQLinear_Exllama
elif use_exllama_v2: elif use_exllama_v2:
q_linear_module = WQLinear_ExllamaV2 q_linear_module = WQLinear_ExllamaV2
elif version == "GEMM": elif version == "gemm":
q_linear_module = WQLinear_GEMM q_linear_module = WQLinear_GEMM
elif version == "GEMV": elif version == "gemv":
q_linear_module = WQLinear_GEMV q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear( q_linear = q_linear_module.from_linear(
......
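Taken together, the base-model changes above rename `max_new_tokens` to `max_seq_len`, make `fuse_layers` and `safetensors` default to `True`, and lower-case the kernel-version checks. A hedged usage sketch of the updated entry point (the checkpoint path is a placeholder, the import path follows the project's examples, and `AWQ_BATCH_SIZE` only matters on the ExLlamaV2 path):

```python
import os

from awq import AutoAWQForCausalLM  # public wrapper around BaseAWQForCausalLM

# Placeholder: any AWQ-quantized checkpoint on the Hub or on local disk.
quant_path = "path/or/repo-of-an-awq-quantized-model"

# Optional: read by exllamav2_post_init when use_exllama_v2=True.
os.environ["AWQ_BATCH_SIZE"] = "1"

model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    max_seq_len=4096,  # previously max_new_tokens
    fuse_layers=True,  # now the default
)
```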
from .base import BaseAWQForCausalLM from .base import BaseAWQForCausalLM
from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomBlock from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomBlock
class BloomAWQForCausalLM(BaseAWQForCausalLM): class BloomAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "BloomBlock" layer_type = "BloomBlock"
...@@ -14,25 +15,30 @@ class BloomAWQForCausalLM(BaseAWQForCausalLM): ...@@ -14,25 +15,30 @@ class BloomAWQForCausalLM(BaseAWQForCausalLM):
is_scalable=True, is_scalable=True,
scale_name="mlp.gelu_impl", scale_name="mlp.gelu_impl",
scale_layer=module.mlp.gelu_impl, scale_layer=module.mlp.gelu_impl,
scale_shape=module.mlp.dense_h_to_4h.out_features scale_shape=module.mlp.dense_h_to_4h.out_features,
) )
@staticmethod @staticmethod
def move_embed(model: BloomForCausalLM, device: str): def move_embed(model: BloomForCausalLM, device: str):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device) model.transformer.word_embeddings_layernorm = (
model.transformer.word_embeddings_layernorm.to(device)
)
@staticmethod @staticmethod
def get_layers_for_scaling(module: BloomBlock, input_feat, module_kwargs): def get_layers_for_scaling(module: BloomBlock, input_feat, module_kwargs):
layers = [] layers = []
# attention input # attention input
layers.append(dict( layers.append(
dict(
prev_op=module.input_layernorm, prev_op=module.input_layernorm,
layers=[module.self_attention.query_key_value], layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'], inp=input_feat["self_attention.query_key_value"],
module2inspect=module, kwargs=module_kwargs, module2inspect=module,
)) kwargs=module_kwargs,
)
)
# attention out # attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469 # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
""" """
...@@ -43,17 +49,22 @@ class BloomAWQForCausalLM(BaseAWQForCausalLM): ...@@ -43,17 +49,22 @@ class BloomAWQForCausalLM(BaseAWQForCausalLM):
)) ))
""" """
# linear 1 # linear 1
layers.append(dict( layers.append(
dict(
prev_op=module.post_attention_layernorm, prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h], layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'], inp=input_feat["mlp.dense_h_to_4h"],
module2inspect=module, kwargs=module_kwargs, module2inspect=module,
)) kwargs=module_kwargs,
)
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.mlp.gelu_impl, prev_op=module.mlp.gelu_impl,
layers=[module.mlp.dense_4h_to_h], layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'], inp=input_feat["mlp.dense_4h_to_h"],
)) )
)
return layers return layers
from .base import BaseAWQForCausalLM from .base import BaseAWQForCausalLM
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention from transformers.models.falcon.modeling_falcon import (
FalconDecoderLayer as OldFalconDecoderLayer,
FalconForCausalLM,
FalconAttention,
)
class FalconAWQForCausalLM(BaseAWQForCausalLM): class FalconAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "FalconDecoderLayer" layer_type = "FalconDecoderLayer"
...@@ -22,7 +27,7 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM): ...@@ -22,7 +27,7 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
is_scalable=True, is_scalable=True,
scale_name="mlp.act", scale_name="mlp.act",
scale_layer=module.mlp.act, scale_layer=module.mlp.act,
scale_shape=module.mlp.dense_h_to_4h.out_features scale_shape=module.mlp.dense_h_to_4h.out_features,
) )
@staticmethod @staticmethod
...@@ -30,45 +35,58 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM): ...@@ -30,45 +35,58 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
@staticmethod @staticmethod
def get_layers_for_scaling(module: OldFalconDecoderLayer, input_feat, module_kwargs): def get_layers_for_scaling(
module: OldFalconDecoderLayer, input_feat, module_kwargs
):
layers = [] layers = []
# Falcon 7B (older architecture) # Falcon 7B (older architecture)
if module.config.num_attention_heads == 71: if module.config.num_attention_heads == 71:
# linear 1 + attention # linear 1 + attention
layers.append(dict( layers.append(
dict(
prev_op=module.input_layernorm, prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value], layers=[
inp=input_feat['self_attention.query_key_value'], module.mlp.dense_h_to_4h,
module.self_attention.query_key_value,
],
inp=input_feat["self_attention.query_key_value"],
module2inspect=module, module2inspect=module,
kwargs=module_kwargs, kwargs=module_kwargs,
)) )
)
# Falcon 40B (newer architecture) # Falcon 40B (newer architecture)
else: else:
# linear 1 + attention # linear 1 + attention
layers.append(dict( layers.append(
dict(
prev_op=module.ln_attn, prev_op=module.ln_attn,
layers=[module.self_attention.query_key_value], layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'], inp=input_feat["self_attention.query_key_value"],
module2inspect=module, module2inspect=module,
kwargs=module_kwargs, kwargs=module_kwargs,
)) )
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.ln_mlp, prev_op=module.ln_mlp,
layers=[module.mlp.dense_h_to_4h], layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'], inp=input_feat["mlp.dense_h_to_4h"],
module2inspect=module, module2inspect=module,
kwargs=module_kwargs, kwargs=module_kwargs,
)) )
)
return layers return layers
from awq.modules.fused.model import FalconModel from awq.modules.fused.model import FalconModel
from awq.modules.fused.block import FalconDecoderLayer from awq.modules.fused.block import FalconDecoderLayer
class FalconFuser: class FalconFuser:
def __init__(self, model: FalconForCausalLM): def __init__(self, model: FalconForCausalLM):
self.model = model self.model = model
...@@ -89,19 +107,21 @@ class FalconFuser: ...@@ -89,19 +107,21 @@ class FalconFuser:
ln_mlp = module.ln_mlp ln_mlp = module.ln_mlp
new_decoder_arch = True new_decoder_arch = True
blocks.append(FalconDecoderLayer( blocks.append(
FalconDecoderLayer(
hidden_size=module.config.hidden_size, hidden_size=module.config.hidden_size,
n_heads=module.config.num_attention_heads, n_heads=module.config.num_attention_heads,
qkv_layer=module.self_attention.query_key_value, qkv_layer=module.self_attention.query_key_value,
o_proj=module.self_attention.dense, o_proj=module.self_attention.dense,
mlp=module.mlp, mlp=module.mlp,
dev=next(iter(module.state_dict().values())).device, dev=next(iter(module.state_dict().values())).device,
max_seq_len=self.model.config.max_new_tokens, max_seq_len=self.model.config.max_seq_len,
input_layernorm=input_layernorm, input_layernorm=input_layernorm,
ln_attn=ln_attn, ln_attn=ln_attn,
ln_mlp=ln_mlp, ln_mlp=ln_mlp,
new_decoder_arch=new_decoder_arch new_decoder_arch=new_decoder_arch,
)) )
)
self.model.transformer = FalconModel( self.model.transformer = FalconModel(
self.model.config.vocab_size, self.model.config.vocab_size,
......
from .base import BaseAWQForCausalLM from .base import BaseAWQForCausalLM
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM, GPTBigCodeBlock as OldGptBigCodeBlock from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeForCausalLM,
GPTBigCodeBlock as OldGptBigCodeBlock,
)
class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM): class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "GPTBigCodeBlock" layer_type = "GPTBigCodeBlock"
max_new_tokens_key = "n_positions" max_seq_len_key = "n_positions"
@staticmethod @staticmethod
def get_model_layers(model: GPTBigCodeForCausalLM): def get_model_layers(model: GPTBigCodeForCausalLM):
...@@ -15,7 +19,7 @@ class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM): ...@@ -15,7 +19,7 @@ class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
is_scalable=True, is_scalable=True,
scale_name="mlp.act", scale_name="mlp.act",
scale_layer=module.mlp.act, scale_layer=module.mlp.act,
scale_shape=module.mlp.c_fc.out_features scale_shape=module.mlp.c_fc.out_features,
) )
@staticmethod @staticmethod
...@@ -25,31 +29,37 @@ class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM): ...@@ -25,31 +29,37 @@ class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
model.transformer.drop = model.transformer.drop.to(device) model.transformer.drop = model.transformer.drop.to(device)
@staticmethod @staticmethod
def get_layers_for_scaling(module:OldGptBigCodeBlock, input_feat, module_kwargs): def get_layers_for_scaling(module: OldGptBigCodeBlock, input_feat, module_kwargs):
layers = [] layers = []
# attention input # attention input
layers.append(dict( layers.append(
dict(
prev_op=module.ln_1, prev_op=module.ln_1,
layers=[module.attn.c_attn], layers=[module.attn.c_attn],
inp=input_feat['attn.c_attn'], inp=input_feat["attn.c_attn"],
module2inspect=module.attn, module2inspect=module.attn,
kwargs=module_kwargs kwargs=module_kwargs,
)) )
)
# linear 1 # linear 1
layers.append(dict( layers.append(
dict(
prev_op=module.ln_2, prev_op=module.ln_2,
layers=[module.mlp.c_fc], layers=[module.mlp.c_fc],
inp=input_feat['mlp.c_fc'], inp=input_feat["mlp.c_fc"],
module2inspect=module.mlp module2inspect=module.mlp,
)) )
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.mlp.act, prev_op=module.mlp.act,
layers=[module.mlp.c_proj], layers=[module.mlp.c_proj],
inp=input_feat['mlp.c_proj'] inp=input_feat["mlp.c_proj"],
)) )
)
return layers return layers
from .base import BaseAWQForCausalLM from .base import BaseAWQForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer, GPTNeoXForCausalLM from transformers.models.gpt_neox.modeling_gpt_neox import (
GPTNeoXLayer,
GPTNeoXForCausalLM,
)
class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM): class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "GPTNeoXDecoderLayer" layer_type = "GPTNeoXDecoderLayer"
max_new_tokens_key = "max_position_embeddings" max_seq_len_key = "max_position_embeddings"
@staticmethod @staticmethod
def get_model_layers(model: GPTNeoXForCausalLM): def get_model_layers(model: GPTNeoXForCausalLM):
...@@ -27,11 +31,13 @@ class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM): ...@@ -27,11 +31,13 @@ class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM):
layers = [] layers = []
# attention input # attention input
layers.append(dict( layers.append(
dict(
prev_op=module.input_layernorm, prev_op=module.input_layernorm,
layers=[module.attention.query_key_value], layers=[module.attention.query_key_value],
inp=input_feat['attention.query_key_value'], inp=input_feat["attention.query_key_value"],
)) )
)
# attention out # attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469 # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
...@@ -44,17 +50,21 @@ class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM): ...@@ -44,17 +50,21 @@ class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM):
""" """
# linear 1 # linear 1
layers.append(dict( layers.append(
dict(
prev_op=module.post_attention_layernorm, prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h], layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'], inp=input_feat["mlp.dense_h_to_4h"],
)) )
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.mlp.act, prev_op=module.mlp.act,
layers=[module.mlp.dense_4h_to_h], layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'], inp=input_feat["mlp.dense_4h_to_h"],
)) )
)
return layers return layers
from .base import BaseAWQForCausalLM from .base import BaseAWQForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock
class GPTJAWQForCausalLM(BaseAWQForCausalLM): class GPTJAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "GPTJBlock" layer_type = "GPTJBlock"
max_new_tokens_key = "n_positions" max_seq_len_key = "n_positions"
@staticmethod @staticmethod
def get_model_layers(model: GPTJForCausalLM): def get_model_layers(model: GPTJForCausalLM):
...@@ -15,7 +16,7 @@ class GPTJAWQForCausalLM(BaseAWQForCausalLM): ...@@ -15,7 +16,7 @@ class GPTJAWQForCausalLM(BaseAWQForCausalLM):
is_scalable=True, is_scalable=True,
scale_name="mlp.act", scale_name="mlp.act",
scale_layer=module.mlp.act, scale_layer=module.mlp.act,
scale_shape=module.mlp.fc_in.out_features scale_shape=module.mlp.fc_in.out_features,
) )
@staticmethod @staticmethod
...@@ -27,27 +28,37 @@ class GPTJAWQForCausalLM(BaseAWQForCausalLM): ...@@ -27,27 +28,37 @@ class GPTJAWQForCausalLM(BaseAWQForCausalLM):
layers = [] layers = []
# attention input + linear 1 # attention input + linear 1
layers.append(dict( layers.append(
dict(
prev_op=module.ln_1, prev_op=module.ln_1,
layers=[module.attn.q_proj, layers=[
module.attn.k_proj, module.attn.v_proj, module.mlp.fc_in], module.attn.q_proj,
inp=input_feat['attn.q_proj'], module.attn.k_proj,
module.attn.v_proj,
module.mlp.fc_in,
],
inp=input_feat["attn.q_proj"],
module2inspect=module, module2inspect=module,
kwargs=module_kwargs kwargs=module_kwargs,
)) )
)
# attention out # attention out
layers.append(dict( layers.append(
dict(
prev_op=module.attn.v_proj, prev_op=module.attn.v_proj,
layers=[module.attn.out_proj], layers=[module.attn.out_proj],
inp=input_feat['attn.out_proj'], inp=input_feat["attn.out_proj"],
)) )
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.mlp.act, prev_op=module.mlp.act,
layers=[module.mlp.fc_out], layers=[module.mlp.fc_out],
inp=input_feat['mlp.fc_out'], inp=input_feat["mlp.fc_out"],
)) )
)
return layers return layers
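The model files above (Bloom, GPT-BigCode, GPT-NeoX, GPT-J) all follow the same integration pattern, now keyed on `max_seq_len_key` instead of `max_new_tokens_key`. A condensed, hypothetical skeleton of that pattern (class, block, and attribute names are placeholders; a real integration subclasses `BaseAWQForCausalLM` as in the files above):

```python
from typing import Any, Dict, List


class MyModelAWQSketch:  # hypothetical; real classes subclass BaseAWQForCausalLM
    # Class name of the transformer block, and the config attribute that holds
    # the maximum sequence length (renamed from max_new_tokens_key in this release).
    layer_type = "MyDecoderBlock"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def get_model_layers(model: Any) -> List[Any]:
        # Return the list of transformer blocks that AWQ should quantize.
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: Any) -> Dict[str, Any]:
        # Models without a scalable activation simply report is_scalable=False.
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: Any, device: str) -> None:
        # Move the input embeddings to the target device before calibration.
        model.transformer.wte = model.transformer.wte.to(device)

    @staticmethod
    def get_layers_for_scaling(module: Any, input_feat: Dict, module_kwargs: Dict) -> List[Dict]:
        # Each entry pairs a preceding op with the linear layers it feeds and the
        # captured calibration inputs used when searching for scales.
        return [
            dict(
                prev_op=module.ln_1,
                layers=[module.attn.c_attn],
                inp=input_feat["attn.c_attn"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            ),
        ]
```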
...@@ -6,13 +6,14 @@ from awq.modules.fused.block import LlamaLikeBlock ...@@ -6,13 +6,14 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer, LlamaDecoderLayer as OldLlamaDecoderLayer,
LlamaForCausalLM as OldLlamaForCausalLM LlamaForCausalLM as OldLlamaForCausalLM,
) )
from awq.modules.fused.norm import FasterTransformerRMSNorm from awq.modules.fused.norm import FasterTransformerRMSNorm
class LlamaAWQForCausalLM(BaseAWQForCausalLM): class LlamaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "LlamaDecoderLayer" layer_type = "LlamaDecoderLayer"
max_new_tokens_key = "max_position_embeddings" max_seq_len_key = "max_position_embeddings"
@staticmethod @staticmethod
def fuse_layers(model: OldLlamaForCausalLM): def fuse_layers(model: OldLlamaForCausalLM):
...@@ -25,9 +26,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM): ...@@ -25,9 +26,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod @staticmethod
def get_act_for_scaling(module: OldLlamaDecoderLayer): def get_act_for_scaling(module: OldLlamaDecoderLayer):
return dict( return dict(is_scalable=False)
is_scalable=False
)
@staticmethod @staticmethod
def move_embed(model: OldLlamaForCausalLM, device: str): def move_embed(model: OldLlamaForCausalLM, device: str):
...@@ -38,37 +37,49 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM): ...@@ -38,37 +37,49 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
layers = [] layers = []
# attention input # attention input
layers.append(dict( layers.append(
dict(
prev_op=module.input_layernorm, prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj, layers=[
module.self_attn.k_proj, module.self_attn.v_proj], module.self_attn.q_proj,
inp=input_feat['self_attn.q_proj'], module.self_attn.k_proj,
module2inspect=module.self_attn, kwargs=module_kwargs, module.self_attn.v_proj,
)) ],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out # attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict( layers.append(
dict(
prev_op=module.self_attn.v_proj, prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj], layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'], inp=input_feat["self_attn.o_proj"],
)) )
)
# linear 1 # linear 1
layers.append(dict( layers.append(
dict(
prev_op=module.post_attention_layernorm, prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj], layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'], inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp, module2inspect=module.mlp,
)) )
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.mlp.up_proj, prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj], layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'], inp=input_feat["mlp.down_proj"],
)) )
)
return layers return layers
...@@ -78,8 +89,9 @@ class LlamaFuser: ...@@ -78,8 +89,9 @@ class LlamaFuser:
self.model = model self.model = model
self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules() (name, module)
if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower() for name, module in self.model.named_modules()
if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower()
] ]
def fuse_transformer(self): def fuse_transformer(self):
...@@ -92,17 +104,17 @@ class LlamaFuser: ...@@ -92,17 +104,17 @@ class LlamaFuser:
module, module,
module.self_attn.q_proj, module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.k_proj,
module.self_attn.v_proj module.self_attn.v_proj,
) )
norm_1 = FasterTransformerRMSNorm( norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.weight, module.input_layernorm.variance_epsilon
module.input_layernorm.variance_epsilon
) )
norm_2 = FasterTransformerRMSNorm( norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight, module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon module.post_attention_layernorm.variance_epsilon,
) )
blocks.append(LlamaLikeBlock( blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size, hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads, n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads, n_kv_heads=self.model.config.num_key_value_heads,
...@@ -112,9 +124,10 @@ class LlamaFuser: ...@@ -112,9 +124,10 @@ class LlamaFuser:
norm_1=norm_1, norm_1=norm_1,
norm_2=norm_2, norm_2=norm_2,
dev=device, dev=device,
max_seq_len=self.model.config.max_new_tokens, max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta rope_theta=self.model.config.rope_theta,
)) )
)
self.model.model = LlamaLikeModel( self.model.model = LlamaLikeModel(
self.model.config.vocab_size, self.model.config.vocab_size,
......
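The Llama fuser above passes the q/k/v projections into a fused block and now sizes its cache from `config.max_seq_len`. As a conceptual aside, QKV fusion amounts to concatenating the three projection weights so a single matmul produces the stacked [Q; K; V] output; a minimal dense sketch of that idea (AutoAWQ itself fuses quantized WQLinear layers, so this is illustrative only):

```python
import torch
import torch.nn as nn


def fuse_qkv_dense(q: nn.Linear, k: nn.Linear, v: nn.Linear) -> nn.Linear:
    # Stack the three projection matrices row-wise; the attention kernel splits
    # the concatenated output back into Q, K, and V afterwards.
    fused = nn.Linear(
        q.in_features,
        q.out_features + k.out_features + v.out_features,
        bias=False,
    )
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
    return fused
```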
...@@ -7,12 +7,15 @@ from awq.modules.fused.model import LlamaLikeModel ...@@ -7,12 +7,15 @@ from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer, LlamaDecoderLayer as OldLlamaDecoderLayer,
) )
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration as OldLlavaForConditionalGeneration from transformers.models.llava.modeling_llava import (
LlavaForConditionalGeneration as OldLlavaForConditionalGeneration,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm from awq.modules.fused.norm import FasterTransformerRMSNorm
class LlavaAWQForCausalLM(BaseAWQForCausalLM): class LlavaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "LlamaDecoderLayer" layer_type = "LlamaDecoderLayer"
max_new_tokens_key = "max_position_embeddings" max_seq_len_key = "max_position_embeddings"
@staticmethod @staticmethod
def fuse_layers(model: OldLlavaForConditionalGeneration): def fuse_layers(model: OldLlavaForConditionalGeneration):
...@@ -25,50 +28,62 @@ class LlavaAWQForCausalLM(BaseAWQForCausalLM): ...@@ -25,50 +28,62 @@ class LlavaAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod @staticmethod
def get_act_for_scaling(module: OldLlamaDecoderLayer): def get_act_for_scaling(module: OldLlamaDecoderLayer):
return dict( return dict(is_scalable=False)
is_scalable=False
)
@staticmethod @staticmethod
def move_embed(model: OldLlavaForConditionalGeneration, device: str): def move_embed(model: OldLlavaForConditionalGeneration, device: str):
model.language_model.model.embed_tokens = model.get_input_embeddings().to(device) model.language_model.model.embed_tokens = model.get_input_embeddings().to(
device
)
@staticmethod @staticmethod
def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
layers = [] layers = []
# attention input # attention input
layers.append(dict( layers.append(
dict(
prev_op=module.input_layernorm, prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj, layers=[
module.self_attn.k_proj, module.self_attn.v_proj], module.self_attn.q_proj,
inp=input_feat['self_attn.q_proj'], module.self_attn.k_proj,
module2inspect=module.self_attn, kwargs=module_kwargs, module.self_attn.v_proj,
)) ],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out # attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict( layers.append(
dict(
prev_op=module.self_attn.v_proj, prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj], layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'], inp=input_feat["self_attn.o_proj"],
)) )
)
# linear 1 # linear 1
layers.append(dict( layers.append(
dict(
prev_op=module.post_attention_layernorm, prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj], layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'], inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp, module2inspect=module.mlp,
)) )
)
# linear 2 # linear 2
layers.append(dict( layers.append(
dict(
prev_op=module.mlp.up_proj, prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj], layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'], inp=input_feat["mlp.down_proj"],
)) )
)
return layers return layers
...@@ -78,8 +93,9 @@ class LlavaFuser: ...@@ -78,8 +93,9 @@ class LlavaFuser:
self.model = model.language_model self.model = model.language_model
self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules() (name, module)
if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower() for name, module in self.model.named_modules()
if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower()
] ]
def fuse_transformer(self): def fuse_transformer(self):
...@@ -92,17 +108,17 @@ class LlavaFuser: ...@@ -92,17 +108,17 @@ class LlavaFuser:
module, module,
module.self_attn.q_proj, module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.k_proj,
module.self_attn.v_proj module.self_attn.v_proj,
) )
norm_1 = FasterTransformerRMSNorm( norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.weight, module.input_layernorm.variance_epsilon
module.input_layernorm.variance_epsilon
) )
norm_2 = FasterTransformerRMSNorm( norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight, module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon module.post_attention_layernorm.variance_epsilon,
) )
blocks.append(LlamaLikeBlock( blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size, hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads, n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads, n_kv_heads=self.model.config.num_key_value_heads,
...@@ -112,8 +128,9 @@ class LlavaFuser: ...@@ -112,8 +128,9 @@ class LlavaFuser:
norm_1=norm_1, norm_1=norm_1,
norm_2=norm_2, norm_2=norm_2,
dev=device, dev=device,
max_seq_len=self.model.config.max_new_tokens max_seq_len=self.model.config.max_seq_len,
)) )
)
self.model.model = LlamaLikeModel( self.model.model = LlamaLikeModel(
self.model.config.vocab_size, self.model.config.vocab_size,
......