"...text-generation-inference.git" did not exist on "ab7ccf5bc3c84e07d0faf0d950421fcdc29743b5"
Unverified commit 72f954ce authored by Casper, committed by GitHub

Merge pull request #62 from casper-hansen/refactor_quant

Refactor quantization code
parents a5e8b048 d6cf1442
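This refactor collapses the old two-step search/quant flow into a single quantize() call that delegates to the new AwqQuantizer class. Roughly, the intended usage looks like the sketch below (a minimal sketch assembled from the code in this diff; the paths and config values are borrowed from the examples and are not part of the commit itself):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

# Load the FP16 model and its tokenizer
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Search for scales/clipping and quantize the weights in one call
model.quantize(tokenizer, quant_config=quant_config)

# Save the quantized model and tokenizer
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)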
import os
import time
import torch
import argparse
from lm_eval import evaluator
from awq import AutoAWQForCausalLM
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from transformers import AutoTokenizer, GenerationConfig
def load_search_result_into_memory(model, search_path):
awq_results = torch.load(search_path, map_location="cpu")
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
def run_search(model_path, dump_path, quant_config):
"""
Step 1/2: Search the Pile calibration data for optimal scaling factors.
"""
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)
# Save search results
model.save_quantized(dump_path)
# Save tokenizer
tokenizer.save_pretrained(dump_path)
def run_quant(model_path, search_path, dump_path, quant_config):
"""
Step 2/2: Use the search results to quantize model weights
"""
# Load model and search results
model = AutoAWQForCausalLM.from_pretrained(model_path)
load_search_result_into_memory(model.model, search_path)
# Run actual weight quantization
model.quantize(quant_config=quant_config, run_search=False, run_quant=True)
# Save quantized model
model.save_quantized(dump_path)
def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
"""
Post-quantization: evaluate perplexity on WikiText with the EleutherAI LM Evaluation Harness.
"""
# Load model
if task_use_pretrained:
model = AutoAWQForCausalLM.from_pretrained(model_path)
else:
model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Load adapter
lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
# Evaluate perplexity of quantized model
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=tasks.split(','),
batch_size=task_batch_size,
no_cache=True,
num_fewshot=task_n_shot,
)
print(evaluator.make_table(results))
@torch.inference_mode()
def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, batch_size=1, disable_fused_layers=False):
def _timer(func):
start = time.time()
out = func()
return out, time.time() - start
def _warmup(device:str):
warm_up = torch.randn((4096,4096)).to(device)
torch.mm(warm_up,warm_up)
if quant_file:
fuse_layers = False if disable_fused_layers else True
model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=fuse_layers))
else:
model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
_warmup(device)
# Generate random inputs
n_context = n_context - n_generate
ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()
# Context stage
_, context_time = _timer(lambda: model.generate(
ids,
generation_config=GenerationConfig(
max_new_tokens=0,
min_new_tokens=0,
use_cache=True
)
))
# Generation stage
_, generation_time = _timer(lambda: model.generate(
ids,
generation_config=GenerationConfig(
max_new_tokens=n_context,
min_new_tokens=n_context,
forced_eos_token_id=-100,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=-100,
use_cache=True
)
))
# Prints
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
context_tokens_per_second = n_context / context_time * batch_size
context_ms_per_token = (context_time*1000) / n_context / batch_size
inference_tokens_per_second = n_generate / generation_time * batch_size
inference_ms_per_token = (generation_time*1000) / n_generate / batch_size
print(f"[=] Model summary: {model_path} [=]")
print(f"[*] Load time: {load_time:.2f} seconds")
print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
print(f"[*] VRAM: {memory_used:.2f} MB")
if __name__ == '__main__':
"""
- Run AWQ search and save result:
python -m awq.entry --entry_type search --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq
- Run AWQ to save the real quantized weights at the quant_path:
python -m awq.entry --entry_type quant --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt --quant_path vicuna-7b-v1.5-awq
- Evaluate perplexity of the quantized model:
python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
- Evaluate perplexity of the unquantized FP16 model:
python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
- Run a speedtest to benchmark the quantized model:
python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --n_generate 128 --n_context 256
- Run a speedtest to benchmark the unquantized FP16 model:
python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --n_context 256
"""
parser = argparse.ArgumentParser()
parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
parser.add_argument('--model_path', type=str, help='Path to hf model')
parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
parser.add_argument('--quant_path', type=str, help='Path to save AWQ model to directory')
parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
parser.add_argument('--w_bit', type=int, default=4)
parser.add_argument('--q_group_size', type=int, default=128)
parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
'Separate tasks by comma for multiple tasks. '
'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
parser.add_argument("--task_use_pretrained", default=False, action='store_true',
help="Pass '--task_use_pretrained' to use a pretrained model running FP16")
parser.add_argument('--task_batch_size', type=int, default=1)
parser.add_argument('--task_n_shot', type=int, default=0)
parser.add_argument('--n_generate', type=int, default=128)
parser.add_argument('--n_context', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument("--disable_fused_layers", default=False, action='store_true',
help="Pass '--disable_fused_layers' to disable fused layers")
args = parser.parse_args()
quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
if args.entry_type == 'search':
run_search(args.model_path, args.search_path, quant_config)
elif args.entry_type == 'quant':
run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
elif args.entry_type == 'eval':
run_eval(args.model_path, args.quant_file, args.device,
args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
elif args.entry_type == 'speed':
run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
else:
raise Exception('--entry_type must be one of (search|quant|eval|speed)')
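The quant_config used above ({"zero_point": True, "q_group_size": 128, "w_bit": 4} with the default arguments) selects 4-bit zero-point quantization over groups of 128 input channels. A self-contained sketch of that arithmetic, rewritten here purely for illustration (this helper is not part of the commit):

import torch

def pseudo_quantize(w: torch.Tensor, w_bit: int = 4, group_size: int = 128) -> torch.Tensor:
    # Illustrative standalone version of the zero-point scheme used by the quantizer in this diff.
    org_shape = w.shape
    w = w.reshape(-1, group_size)                               # one quantization group per row
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2 ** w_bit - 1                                    # 15 for 4-bit weights
    scales = (max_val - min_val).clamp(min=1e-5) / max_int      # fp step size per group
    zeros = (-torch.round(min_val / scales)).clamp(0, max_int)  # integer zero point per group
    w_q = (torch.clamp(torch.round(w / scales) + zeros, 0, max_int) - zeros) * scales
    return w_q.reshape(org_shape)

w = torch.randn(256, 512)
print((w - pseudo_quantize(w)).abs().mean())  # average rounding error of the 4-bit approximation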
@@ -2,26 +2,19 @@ import os
import gc
import json
import torch
-import logging
-import functools
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
-from collections import defaultdict
from safetensors.torch import save_file
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
+from awq.quantize.quantizer import AwqQuantizer
from awq.utils.utils import simple_dispatch_model
-from awq.utils.calib_data import get_calib_dataset
from transformers.modeling_utils import shard_checkpoint
-from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.quantize.auto_clip import auto_clip_block, apply_clip
-from awq.quantize.auto_scale import auto_scale_block, apply_scale
+from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
-from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, quant_config):
@@ -43,238 +36,64 @@ class BaseAWQForCausalLM(nn.Module):
        return self.model.generate(*args, **kwargs)

    @torch.no_grad()
-    def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
-                 auto_scale=True, mse_range=True, run_search=True, run_quant=True,
-                 calib_data: Union[str, List[str]]="pileval", split="train",
-                 text_column="text"):
+    def quantize(self, tokenizer=None, quant_config={},
+                 calib_data: Union[str, List[str]]="pileval",
+                 split="train", text_column="text"):
        self.quant_config = quant_config
        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
-        if run_search:
-            self.search_result = self._awq_search(
-                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
-                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
-                split=split, text_column=text_column
-            )
-        if run_quant:
-            self._awq_quant()
-            self.is_quantized = True
+        quantizer = AwqQuantizer(
+            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
+            quant_config["version"], calib_data, split, text_column
+        )
+        quantizer.quantize()
+        self.is_quantized = True

    @staticmethod
    def fuse_layers(model, quant_config):
        pass

-    def _awq_quant(self):
-        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
-        layers = self.get_model_layers(self.model)
-        # Run AWQ quantization
-        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
-            layer = layers[i]
-            named_linears = get_named_linears(layer)
-            self._scale_activations(self, layer)
-            for name, module in named_linears.items():
-                module.cuda()
-                module.weight.data, scales, zeros = pseudo_quantize_tensor(
-                    module.weight.data,
-                    get_scale_zp=True,
-                    w_bit=self.quant_config["w_bit"],
-                    q_group_size=self.quant_config["q_group_size"]
-                )
-                if self.quant_config["version"] == 'GEMM':
-                    scales = scales.t().contiguous()
-                    zeros = zeros.t().contiguous()
-                    q_linear_module = WQLinear_GEMM
-                elif self.quant_config["version"] == 'GEMV':
-                    q_linear_module = WQLinear_GEMV
-                q_linear = q_linear_module.from_linear(
-                    module,
-                    self.quant_config['w_bit'],
-                    self.quant_config['q_group_size'],
-                    False,
-                    scales,
-                    zeros
-                )
-                module.cpu()
-                q_linear.to(next(layer.parameters()).device)
-                set_op_by_name(layer, name, q_linear)
-                torch.cuda.empty_cache()
-                gc.collect()
-            torch.cuda.empty_cache()
-            gc.collect()
-
-    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
-                    auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval",
-                    split="train", text_column="text"):
-        layers = self.get_model_layers(self.model)
-        samples = get_calib_dataset(
-            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
-            split=split, text_column=text_column
-        )
-        samples = torch.cat(samples, dim=0)
-        inps = []
-        layer_kwargs = {}
-        layers[0] = layers[0].cuda()
-        self.move_embed(self.model, "cuda")
-        # get input and kwargs to layer 0
-        # with_kwargs is only supported in PyTorch 2.0
-        # use this Catcher hack for now
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-            def forward(self, hijacked_inputs, **kwargs):
-                inps.append(hijacked_inputs)
-                layer_kwargs.update(kwargs)
-                raise ValueError  # early exit to break later inference
-        # patch layer 0 to catch input and kwargs
-        layers[0] = Catcher(layers[0])
-        try:
-            self.model(samples.to(next(self.model.parameters()).device))
-        except ValueError:  # work with early exit
-            pass
-        del samples
-        layers[0] = layers[0].module  # restore
-        inps = inps[0]
-        layers[0] = layers[0].cpu()
-        self.move_embed(self.model, "cpu")
-        gc.collect()
-        torch.cuda.empty_cache()
-        awq_results = {
-            "scale": [],
-            "clip": [],
-        }
-        # Run AWQ search layer by layer
-        for i in tqdm(range(len(layers)), desc="AWQ Search"):
-            layer = layers[i]
-            layer = layer.cuda()
-            named_linears = get_named_linears(layer)
-            # firstly, get input features of all linear layers
-            def cache_input_hook(m, x, y, name, feat_dict):
-                x = x[0]
-                x = x.detach().cpu()
-                feat_dict[name].append(x)
-            input_feat = defaultdict(list)
-            handles = []
-            for name in named_linears:
-                handles.append(named_linears[name].register_forward_hook(
-                    functools.partial(cache_input_hook, name=name,
-                                      feat_dict=input_feat)))
-            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
-            # get output as next layer's input
-            inps = layer(inps, **layer_kwargs)[0]
-            for h in handles:
-                h.remove()
-            # now solve for scaling and clipping
-            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
-            # Clear GPU memory
-            torch.cuda.empty_cache()
-            if auto_scale:  # if it applies, we should also modify the input_feat with scales
-                scales_list = auto_scale_block(
-                    self,
-                    layer,
-                    layer_kwargs,
-                    quant_config=quant_config,
-                    input_feat=input_feat,
-                )
-                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
-                # append prefix to make names global
-                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")
-            # Clear GPU memory
-            torch.cuda.empty_cache()
-            if mse_range:
-                clip_list = auto_clip_block(
-                    layer,
-                    quant_config=quant_config,
-                    input_feat=input_feat
-                )
-                apply_clip(layer, clip_list)
-                # append prefix to make names global
-                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")
-            layer = layer.cpu()
-            # Haotian: check activation replacement
-            del input_feat
-            gc.collect()
-            torch.cuda.empty_cache()
-        return awq_results

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
-        def _save_files(save_dir, model_name='', search_result=None):
-            class EmptyModule(nn.Module):
-                def __init__(self): super(EmptyModule, self).__init__()
-                def forward(self, x): return x
-            # Save model files with empty state dict
-            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
-            # Remove empty state dict
-            os.remove(f'{save_dir}/pytorch_model.bin')
-            if search_result is not None:
-                torch.save(search_result, f'{save_dir}/{model_name}')
-            else:
-                # model_name has no extension, add it when saving state_dict
-                model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
-                # shard checkpoint into chunks (10GB default)
-                shards, index = shard_checkpoint(
-                    self.model.state_dict(),
-                    max_shard_size=shard_size,
-                    weights_name=model_name
-                )
-                for shard_file, shard in shards.items():
-                    if safetensors:
-                        # safetensors must be in the same memory, so we duplicate and use contiguous memory
-                        shard = {k: v.clone().contiguous() for k, v in shard.items()}
-                        save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
-                    else:
-                        torch.save(shard, os.path.join(save_dir, shard_file))
-                # save shard index
-                if index is not None:
-                    with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
-                        file.write(json.dumps(index, indent=4))
-            # Save config
-            with open(f'{save_dir}/quant_config.json', 'w+') as file:
-                file.write(json.dumps(self.quant_config, indent=4))
-
-        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
-        # Save model
-        if self.search_result is None or self.is_quantized:
-            _save_files(save_dir, '', search_result=None)
-        else:
-            model_name = 'awq_model_search_result.pt'
-            _save_files(save_dir, model_name, self.search_result)
+        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
+        # Save model
+        class EmptyModule(nn.Module):
+            def __init__(self): super(EmptyModule, self).__init__()
+            def forward(self, x): return x
+        # Save model files with empty state dict
+        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
+        # Remove empty state dict
+        os.remove(f'{save_dir}/pytorch_model.bin')
+        # model_name has no extension, add it when saving state_dict
+        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
+        # shard checkpoint into chunks (10GB default)
+        shards, index = shard_checkpoint(
+            self.model.state_dict(),
+            max_shard_size=shard_size,
+            weights_name=model_name
+        )
+        for shard_file, shard in shards.items():
+            if safetensors:
+                # safetensors must be in the same memory, so we duplicate and use contiguous memory
+                shard = {k: v.clone().contiguous() for k, v in shard.items()}
+                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
+            else:
+                torch.save(shard, os.path.join(save_dir, shard_file))
+        # save shard index
+        if index is not None:
+            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
+                file.write(json.dumps(index, indent=4))
+        # Save config
+        with open(f'{save_dir}/quant_config.json', 'w+') as file:
+            file.write(json.dumps(self.quant_config, indent=4))

    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
...
@@ -27,10 +27,12 @@ class OptAWQForCausalLM(BaseAWQForCausalLM):
        # attention input
        layers.append(dict(
            prev_op=module.self_attn_layer_norm,
-            layers=[module.self_attn.q_proj,
-                    module.self_attn.k_proj, module.self_attn.v_proj],
+            layers=[
+                module.self_attn.q_proj,
+                module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
-            module2inspect=module.self_attn, kwargs=module_kwargs,
+            module2inspect=module.self_attn,
+            kwargs=module_kwargs,
        ))

        # attention out
...
import torch
import torch.nn as nn
from .quantizer import pseudo_quantize_tensor
import gc
__all__ = ["auto_clip_block"]
# weight quantization
@torch.no_grad()
def auto_clip_layer(w,
input_feat,
quant_config,
n_grid=20,
max_shrink=0.5,
n_sample_token=512):
assert w.dim() == 2
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
w = w.reshape(w.shape[0], 1, -1, group_size)
oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64 # prevent OOM
assert w.shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []
for i_b in range(w.shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1
best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
input_feat = input_feat.to(w.device)
org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group
for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
min_val = - max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = pseudo_quantize_tensor(cur_w, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"])
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
del cur_w
del cur_out
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
best_max_val[cur_best_idx] = max_val[cur_best_idx]
best_max_val_all.append(best_max_val)
best_max_val = torch.cat(best_max_val_all, dim=0)
del input_feat
del org_out
gc.collect()
torch.cuda.empty_cache()
return best_max_val.squeeze(1)
@torch.no_grad()
def auto_clip_block(module,
quant_config,
input_feat):
named_linears = {name: m for name,
m in module.named_modules() if isinstance(m, nn.Linear)}
clip_list = []
for name in named_linears:
# due to qk bmm, it is hard to clip precisely
if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
continue
named_linears[name].cuda()
max_val = auto_clip_layer(
named_linears[name].weight, input_feat[name], quant_config=quant_config)
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
@torch.no_grad()
def apply_clip(module, clip_list):
from ..utils.module import get_op_by_name
for name, max_val in clip_list:
layer = get_op_by_name(module, name)
layer.cuda()
max_val = max_val.to(layer.weight.device)
org_shape = layer.weight.shape
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
layer.weight.data = layer.weight.data.reshape(org_shape)
layer.cpu()
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
class AwqQuantizer:
def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
calib_data, split, text_column) -> None:
self.awq_model = awq_model
self.model = model
self.tokenizer = tokenizer
self.w_bit = w_bit
self.group_size = group_size
self.version = version
self.calib_data = calib_data
self.split = split
self.text_column = text_column
self.modules, self.module_kwargs, self.inps = self.init_quant()
-# core quantization method (simulated quantization)
-def pseudo_quantize_tensor(w, w_bit=4,
-zero_point=True,
-q_group_size=-1,
-inplace=False,
-get_scale_zp=False
-):
-org_w_shape = w.shape
-if q_group_size > 0:
-assert org_w_shape[-1] % q_group_size == 0
-w = w.reshape(-1, q_group_size)
-assert w.dim() == 2
-if zero_point:
-max_val = w.amax(dim=1, keepdim=True)
-min_val = w.amin(dim=1, keepdim=True)
-max_int = 2 ** w_bit - 1
-min_int = 0
-scales = (max_val - min_val).clamp(min=1e-5) / max_int
-zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
-else: # we actually never used this
-assert min_val is None
-max_val = w.abs().amax(dim=1, keepdim=True)
-max_val = max_val.clamp(min=1e-5)
-max_int = 2 ** (w_bit - 1) - 1
-min_int = - 2 ** (w_bit - 1)
-scales = max_val / max_int
-zeros = 0
-assert torch.isnan(scales).sum() == 0
-assert torch.isnan(w).sum() == 0
-if inplace:
-((w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)).mul_(scales)
-else:
-w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
-assert torch.isnan(w).sum() == 0
-w = w.reshape(org_w_shape)
-if get_scale_zp:
-return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
-else:
-return w
def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
org_w_shape = w.shape
if self.group_size > 0:
assert org_w_shape[-1] % self.group_size == 0
w = w.reshape(-1, self.group_size)
assert w.dim() == 2
# zero point quantization
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2 ** self.w_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
assert torch.isnan(w).sum() == 0
w = w.reshape(org_w_shape)
if get_scale_zp:
return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
return w
def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# [STEP 1]: Get layer, extract linear modules, extract input features
self.modules[i] = self.modules[i].cuda()
named_linears = get_named_linears(self.modules[i])
input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory()
# [STEP 2]: Compute and apply scale list
module_config: list[dict] = self.awq_model.get_layers_for_scaling(
self.modules[i], input_feat, self.module_kwargs
)
scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")
# [STEP 3]: Compute and apply clipping list
clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
apply_clip(self.modules[i], clip_list)
clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
# [STEP 4]: Quantize weights
self._apply_quant(self.modules[i], named_linears)
clear_memory()
def _apply_quant(self, module, named_linears: dict[str, nn.Linear]):
for name, linear_layer in named_linears.items():
# NOTE: small regression in perplexity if linear layer uses .cpu().float()
linear_layer = linear_layer.cuda().half()
linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
linear_layer.weight.data,
get_scale_zp=True
)
if self.version == 'GEMM':
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM
elif self.version == 'GEMV':
q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear(
linear=linear_layer,
w_bit=self.w_bit,
group_size=self.group_size,
init_only=False,
scales=scales,
zeros=zeros
)
linear_layer.cpu()
q_linear.to(next(module.parameters()).device)
set_op_by_name(module, name, q_linear)
clear_memory()
@torch.no_grad()
def _search_best_scale(self, module, prev_op, layers: list[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
if module2inspect is None:
assert len(layers) == 1
module2inspect = layers[0]
if "use_cache" in kwargs:
kwargs.pop("use_cache")
# Put x on the right device
inp = inp.to(next(module2inspect.parameters()).device)
# [STEP 1]: Compute maximum of weight
weight = torch.cat([_m.weight for _m in layers], dim=0)
org_shape = weight.shape
weight = weight.view(-1, self.group_size)
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
w_scale = w_scale.view(org_shape)
w_max = w_scale.mean(0)
clear_memory(weight)
# [STEP 2]: Compute maximum of x
x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)
# [STEP 3]: Compute output of module
with torch.no_grad():
fp16_output = module2inspect(inp, **kwargs)
if isinstance(fp16_output, tuple):
fp16_output = fp16_output[0]
# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
inp, w_max, x_max, module2inspect,
layers, fp16_output, kwargs
)
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear],
fp16_output, kwargs={}):
"""
Compute loss and select best scales
L(s) = || Q(W * s) (s^-1 * X) - W * X ||
Q: weight quantization function | pseudo_quantize_tensor(W * s)
X: inputs from calib dataset | X
W: original weights in FP16 | layer
s: per channel scaling factor | s^-1 * X
"""
n_grid = 20
history = []
best_ratio = -1
best_scales = None
best_error = float('inf')
org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
device = x.device
x_max = x_max.view(-1).to(device)
w_max = w_max.view(-1).to(device)
for ratio in range(n_grid):
# create new scales
ratio = ratio / n_grid
# NOTE: s^-1 * x is fused here, according to paper
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
scales = scales / (scales.max() * scales.min()).sqrt()
scales_view = scales.view(1, -1).to(device)
# Q(W * s)
for fc in linears2scale:
fc.weight.mul_(scales_view)
fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
# W * X
int_w_output = module2inspect(x, **kwargs)
if isinstance(int_w_output, tuple):
int_w_output = int_w_output[0]
# compute mean squared error (L2 norm)
loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
history.append(loss)
if loss < best_error:
best_error = loss
best_ratio = ratio
best_scales = scales.clone()
module2inspect.load_state_dict(org_sd)
if best_ratio == -1:
logging.debug(history)
raise Exception
assert torch.isnan(best_scales).sum() == 0, best_scales
return best_scales.detach().cpu()
@torch.no_grad()
def _search_best_clip(self, layer, named_linears, input_feat):
clip_list = []
avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
for name in named_linears:
# due to qk bmm, it is hard to clip precisely
if any([_ in name for _ in avoid_clipping]):
continue
named_linears[name].cuda()
max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
@torch.no_grad()
def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor, n_grid=20, max_shrink=0.5, n_sample_token=512):
assert w.dim() == 2
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = self.group_size if self.group_size > 0 else w.shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
w = w.reshape(w.shape[0], 1, -1, group_size)
oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64 # prevent OOM
assert w.shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []
for i_b in range(w.shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1
best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
input_feat = input_feat.to(w.device)
org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group
for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
min_val = - max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = self.pseudo_quantize_tensor(cur_w)
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
del cur_w
del cur_out
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
best_max_val[cur_best_idx] = max_val[cur_best_idx]
best_max_val_all.append(best_max_val)
best_max_val = torch.cat(best_max_val_all, dim=0)
clear_memory(input_feat)
clear_memory(org_out)
return best_max_val.squeeze(1)
def init_quant(self, n_samples=128, seqlen=512):
modules = self.awq_model.get_model_layers(self.model)
samples = get_calib_dataset(
data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples, block_size=seqlen,
split=self.split, text_column=self.text_column
)
samples = torch.cat(samples, dim=0)
inps = []
layer_kwargs = {}
modules[0] = modules[0].cuda()
self.awq_model.move_embed(self.model, "cuda")
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, hijacked_inputs, **kwargs):
inps.append(hijacked_inputs)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference
# patch layer 0 to catch input and kwargs
modules[0] = Catcher(modules[0])
try:
self.model(samples.to(next(self.model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
modules[0] = modules[0].module # restore
inps = inps[0]
modules[0] = modules[0].cpu()
self.awq_model.move_embed(self.model, "cpu")
clear_memory()
return modules, layer_kwargs, inps
def _get_input_feat(self, layer, named_linears):
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
x = x[0]
x = x.detach().cpu()
feat_dict[name].append(x)
input_feat = defaultdict(list)
handles = []
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
feat_dict=input_feat)))
self.inps = self.inps.to(next(layer.parameters()).device) # in case multi-gpu
# get output as next layer's input
self.inps = layer(self.inps, **self.module_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
return input_feat
-import gc
import torch
import torch.nn as nn
-import logging
-from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
-from transformers.models.opt.modeling_opt import OPTDecoderLayer
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
-from transformers.activations import NewGELUActivation
+from typing import Tuple
from awq.modules.act import ScaledActivation
-from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
-__all__ = ["auto_scale_block", "apply_scale"]
+from transformers.activations import NewGELUActivation
+from awq.utils.module import get_op_by_name, set_op_by_name
+from transformers.models.bloom.modeling_bloom import BloomGelu
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
+allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation]

@torch.no_grad()
-def get_weight_scale(weight, q_group_size=-1):
-    org_shape = weight.shape
-    if q_group_size > 0:
-        weight = weight.view(-1, q_group_size)
-    scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
-    scale = scale.view(org_shape)
-    scale = scale.mean(0)
-    return scale
-
-@torch.no_grad()
-def get_act_scale(x):
-    return x.abs().view(-1, x.shape[-1]).mean(0)
+def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
+    for name, max_val in clip_list:
+        layer: nn.Linear = get_op_by_name(module, name)
+        layer.cuda()
+        max_val = max_val.to(layer.weight.device)
+        org_shape = layer.weight.shape
+        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
+        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
+        layer.weight.data = layer.weight.data.reshape(org_shape)
+        layer.cpu()
+
+def apply_scale(module, scales_list, input_feat_dict=None):
+    for prev_op_name, layer_names, scales in scales_list:
+        prev_op = get_op_by_name(module, prev_op_name)
+        layers = [get_op_by_name(module, name) for name in layer_names]
+        prev_op.cuda()
+        for layer in layers:
+            layer.cuda()
+        scales.cuda()
+        if isinstance(prev_op, nn.Linear):
+            assert len(layers) == 1
+            scale_fc_fc(prev_op, layers[0], scales)
+        elif any(isinstance(prev_op,t) for t in allowed_norms) \
+                or 'rmsnorm' in str(prev_op.__class__).lower():
+            scale_ln_fcs(prev_op, layers, scales)
+        elif any(isinstance(prev_op,t) for t in allowed_act_fns):
+            new_module = ScaledActivation(prev_op, scales)
+            set_op_by_name(module, prev_op_name, new_module)
+            scale_gelu_fc(prev_op, layers[0], scales)
+        else:
+            raise NotImplementedError(
+                f"prev_op {type(prev_op)} not supported yet!")
+        # apply the scaling to input feat if given; prepare it for clipping
+        if input_feat_dict is not None:
+            for layer_name in layer_names:
+                inp = input_feat_dict[layer_name]
+                inp.div_(scales.view(1, -1).to(inp.device))
+        prev_op.cpu()
+        for layer in layers:
+            layer.cpu()
+        scales.cpu()

@torch.no_grad()
-def scale_ln_fcs(ln, fcs, scales):
+def scale_ln_fcs(ln: nn.Linear, fcs: list[nn.Linear], scales: torch.Tensor):
    if not isinstance(fcs, list):
        fcs = [fcs]
    scales = scales.to(ln.weight.device)
-    # debugging start even scales = 1 does not work?
-    """
-    scales = scales * 0
-    scales = scales + 1
-    """
-    # debugging end
    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

@@ -56,16 +82,13 @@ def scale_ln_fcs(ln, fcs, scales):
    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0

@torch.no_grad()
-def scale_fc_fc(fc1, fc2, scales):
+def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)
-    # assert fc1.out_features == fc2.in_features
    scales = scales.to(fc1.weight.device)
-    # fc1.weight.div_(scales.view(-1, 1))
    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

@@ -79,141 +102,11 @@ def scale_fc_fc(fc1, fc2, scales):
@torch.no_grad()
-def scale_gelu_fc(gelu, fc, scales):
-    assert any(isinstance(gelu,t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
+def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):
+    assert any(isinstance(gelu,t) for t in allowed_act_fns)
    assert isinstance(fc, nn.Linear)
    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
\ No newline at end of file
@torch.no_grad()
def auto_scale_block(awq_model,
module,
module_kwargs,
quant_config,
input_feat):
from .quantizer import pseudo_quantize_tensor
# firstly, get the weight quantize function
if quant_config['w_bit'] is not None:
def w_quantize_func(p): return pseudo_quantize_tensor(p, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"]).detach()
else:
def w_quantize_func(p): return p
if "use_cache" in module_kwargs:
module_kwargs.pop("use_cache")
# find the best scale ratio
def _search_module_scale(block, linears2scale: list, x, kwargs={}):
# w: co, ci
# x: n, ci
weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
w_max = get_weight_scale(
weight, q_group_size=quant_config.get("q_group_size", -1))
# Clear GPU memory
del weight
gc.collect()
torch.cuda.empty_cache()
x = x.to(next(block.parameters()).device)
with torch.no_grad():
org_out = block(x, **kwargs)
if isinstance(org_out, tuple):
org_out = org_out[0]
x_max = get_act_scale(x)
best_error = float('inf')
best_ratio = -1
best_scales = None
n_grid = 20
history = []
org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
for ratio in range(n_grid):
ratio = ratio * 1 / n_grid
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)
).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
fc.weight.data = w_quantize_func(
fc.weight.data) / (scales.view(1, -1))
out = block(x, **kwargs)
if isinstance(out, tuple):
out = out[0]
loss = (org_out - out).float().pow(2).mean().item() # float prevents overflow
history.append(loss)
is_best = loss < best_error
if is_best:
best_error = loss
best_ratio = ratio
best_scales = scales
block.load_state_dict(org_sd)
if best_ratio == -1:
logging.debug(history)
raise Exception
best_scales = best_scales.view(-1)
assert torch.isnan(best_scales).sum() == 0, best_scales
return best_scales.detach()
def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
# module2inspect: if given, we will check the output diff of this module instead of layers
if module2inspect is None:
assert len(layers) == 1
module2inspect = layers[0]
scales = _search_module_scale(module2inspect, layers, inp, kwargs)
scales = scales.detach().cpu()
# prev_op_name, [layer_name], scale
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
layers: list[dict] = awq_model.get_layers_for_scaling(
module, input_feat, module_kwargs
)
scales_list = [_auto_get_scale(**layer) for layer in layers]
return scales_list
def apply_scale(module, scales_list, input_feat_dict=None):
for prev_op_name, layer_names, scales in scales_list:
prev_op = get_op_by_name(module, prev_op_name)
layers = [get_op_by_name(module, name) for name in layer_names]
prev_op.cuda()
for layer in layers:
layer.cuda()
scales.cuda()
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
elif any(isinstance(prev_op,t) for t in [nn.LayerNorm, LlamaRMSNorm]) \
or 'rmsnorm' in str(prev_op.__class__).lower():
scale_ln_fcs(prev_op, layers, scales)
elif any(isinstance(prev_op,t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales)
else:
raise NotImplementedError(
f"prev_op {type(prev_op)} not supported yet!")
# apply the scaling to input feat if given; prepare it for clipping
if input_feat_dict is not None:
for layer_name in layer_names:
inp = input_feat_dict[layer_name]
inp.div_(scales.view(1, -1).to(inp.device))
prev_op.cpu()
for layer in layers:
layer.cpu()
scales.cpu()
import gc
import torch
import accelerate
@@ -53,3 +54,9 @@ def set_module_name(model, name, value):
        child_name = name
    setattr(parent, child_name, value)
def clear_memory(weight=None):
if weight is not None:
del weight
gc.collect()
torch.cuda.empty_cache()
\ No newline at end of file
import argparse
from lm_eval import evaluator
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
"""
Post-quantization: evaluate perplexity on WikiText with the EleutherAI LM Evaluation Harness.
"""
# Load model
if task_use_pretrained:
model = AutoAWQForCausalLM.from_pretrained(model_path)
else:
model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Load adapter
lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
# Evaluate perplexity of quantized model
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=tasks.split(','),
batch_size=task_batch_size,
no_cache=True,
num_fewshot=task_n_shot,
)
print(evaluator.make_table(results))
if __name__ == '__main__':
"""
- Evaluate perplexity of the quantized model:
python examples/eval.py --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
- Evaluate perplexity of the unquantized FP16 model:
python examples/eval.py --use_pretrained --model_path lmsys/vicuna-7b-v1.5
"""
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='Path to hf model')
parser.add_argument('--quant_file', default='', type=str, help='Path to quantized AWQ model file')
parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
parser.add_argument("--use_pretrained", default=False, action='store_true',
help="Pass '--use_pretrained' to use a pretrained model running FP16")
parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
'Separate tasks by comma for multiple tasks. '
'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--n_shot', type=int, default=0)
args = parser.parse_args()
run_eval(args.model_path, args.quant_file, args.device,
args.tasks, args.batch_size, args.n_shot, args.use_pretrained)
\ No newline at end of file