Commit 290f45e7 authored by Casper Hansen

Refactored pre_quant into base

parent d5be2115
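This moves the AWQ search that previously lived in awq.quantize.pre_quant.run_awq onto BaseAWQForCausalLM.quantize, with a model-specific wrapper selected per architecture. A rough sketch of the calling pattern after the change, pieced together from the hunks below (model, enc, w_bit and q_config are assumed to come from the existing entry script, not shown here):

# sketch only; names are taken from this commit's hunks
from awq.models import MptAWQForCausalLM
from awq.quantize.auto_scale import apply_scale
from awq.quantize.auto_clip import apply_clip

def get_awq_model(model):
    # MPT is the only architecture wired into the new base class so far
    if "mpt" in str(model.__class__).lower():
        return MptAWQForCausalLM()
    raise NotImplementedError(type(model))

awq_model = get_awq_model(model)
awq_results = awq_model.quantize(model, enc, w_bit, q_config)  # returns {"scale": [...], "clip": [...]}

# re-applying previously dumped results now calls apply_scale/apply_clip directly,
# replacing the removed pre_quant.apply_awq helper
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])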
@@ -7,11 +7,11 @@ import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from accelerate.utils.modeling import get_balanced_memory
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
@@ -65,7 +65,13 @@ q_config = {
}
print("Quantization config:", q_config)
# build model and tokenizer
def get_awq_model(model):
from awq.models import MptAWQForCausalLM
if "mpt" in str(model.__class__).lower():
return MptAWQForCausalLM()
else:
raise NotImplementedError(type(model))
def build_model_and_enc(model_path):
if not os.path.exists(model_path): # look into ssd
@@ -120,11 +126,9 @@ def build_model_and_enc(model_path):
if args.run_awq:
assert args.dump_awq, "Please save the awq results with --dump_awq"
awq_results = run_awq(
model, enc,
w_bit=args.w_bit, q_config=q_config,
n_samples=128, seqlen=512,
)
awq_model = get_awq_model(model)
awq_results = awq_model.quantize(model, enc, args.w_bit, q_config)
if args.dump_awq:
dirpath = os.path.dirname(args.dump_awq)
os.makedirs(dirpath, exist_ok=True)
@@ -137,7 +141,9 @@ def build_model_and_enc(model_path):
if args.load_awq:
print("Loading pre-computed AWQ results from", args.load_awq)
awq_results = torch.load(args.load_awq, map_location="cpu")
apply_awq(model, awq_results)
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
# weight quantization
if args.w_bit is not None:
...
import gc
import tqdm
import torch
import functools
import torch.nn as nn
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears
class BaseAWQForCausalLM:
def quantize():
@torch.no_grad()
def quantize(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"):
layers = self.get_model_layers(model)
samples = get_calib_dataset(
data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
samples = torch.cat(samples, dim=0)
inps = []
layer_kwargs = {}
layers[0] = layers[0].cuda()
self.move_embed(model, "cuda")
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps.append(inp)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference
# patch layer 0 to catch input and kwargs
layers[0] = Catcher(layers[0])
try:
model(samples.to(next(model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
layers[0] = layers[0].module # restore
inps = inps[0]
layers[0] = layers[0].cpu()
self.move_embed(model, "cpu")
gc.collect()
torch.cuda.empty_cache()
awq_results = {
"scale": [],
"clip": [],
}
# solve layer by layer
for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
layer = layers[i]
layer = layer.cuda()
named_linears = get_named_linears(layer)
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
x = x[0]
x = x.detach().cpu()
feat_dict[name].append(x)
input_feat = defaultdict(list)
handles = []
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
feat_dict=input_feat)))
inps = inps.to(next(layer.parameters()).device) # in case multi-gpu
# get output as next layer's input
inps = layer(inps, **layer_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
# Clear GPU memory
torch.cuda.empty_cache()
if auto_scale: # if it applies, we should also modify the input_feat with scales
scales_list = auto_scale_block(
self,
layer, layer_kwargs,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,
)
# apply_scale(layer, scales_list, input_feat_dict=input_feat)
apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
# Clear GPU memory
torch.cuda.empty_cache()
if mse_range:
clip_list = auto_clip_block(layer,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,)
apply_clip(layer, clip_list)
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
layer = layer.cpu()
# Haotian: check activation replacement
del input_feat
gc.collect()
torch.cuda.empty_cache()
return awq_results
def save_quantized():
pass
...
@@ -3,10 +3,10 @@ from .base import BaseAWQForCausalLM
class MptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MPTBlock"
def get_model_layers(model):
def get_model_layers(self, model):
return model.transformer.blocks
def get_layers_for_scaling(module, input_feat, module_kwargs):
def get_layers_for_scaling(self, module, input_feat, module_kwargs):
layers = []
# attention input
@@ -42,13 +42,13 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
return layers
def get_act_for_scaling(module):
def get_act_for_scaling(self, module):
return dict(
scale_name="ffn.act",
scale_layer=module.ffn.act,
scale_shape=module.ffn.up_proj.out_features
)
def move_embed(model, device):
def move_embed(self, model, device):
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.emb_drop = model.transformer.emb_drop.to(device)
\ No newline at end of file
@@ -7,8 +7,7 @@ from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from .qmodule import ScaledActivation
from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
from ..models import MptAWQForCausalLM
__all__ = ["auto_scale_block", "apply_scale"]
@@ -90,7 +89,8 @@ def scale_gelu_fc(gelu, fc, scales):
@torch.no_grad()
def auto_scale_block(module, module_kwargs,
def auto_scale_block(awq_model,
module, module_kwargs,
w_bit, q_config,
input_feat):
from .quantizer import pseudo_quantize_tensor from .quantizer import pseudo_quantize_tensor
@@ -174,149 +174,10 @@ def auto_scale_block(module, module_kwargs,
# prev_op_name, [layer_name], scale
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
scales_list = [] # return the searched scales
layers: list[dict] = awq_model.get_layers_for_scaling(
if isinstance(module, OPTDecoderLayer):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.self_attn_layer_norm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat['self_attn.out_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat['fc1'],
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat['fc2'],
))
elif isinstance(module, LlamaDecoderLayer):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
scales_list.append(_auto_get_scale(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'],
module2inspect=module.mlp,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
elif isinstance(module, BloomBlock):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module, kwargs=module_kwargs,
))
# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module, kwargs=module_kwargs,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.gelu_impl,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
elif "mpt" in str(module.__class__).lower():
layers: list[dict] = MptAWQForCausalLM.get_layers_for_scaling(
module, input_feat, module_kwargs
)
layers_scaled = [_auto_get_scale(**layer) for layer in layers]
scales_list = [_auto_get_scale(**layer) for layer in layers]
scales_list.extend(layers_scaled)
elif "falcon" in str(module.__class__).lower():
# attn out
# Haotian: TBD: need to handle repeated scales for MQ
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1, as long as it is scaled, everything is screwed up
if "falcon-7b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
elif "falcon-40b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.ln_attn,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
scales_list.append(_auto_get_scale(
prev_op=module.ln_mlp,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module,
kwargs=module_kwargs,
))
else:
raise NotImplementedError("Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.act,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
else:
raise NotImplementedError(f"{type(module)} not supported yet!")
return scales_list
...
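auto_scale_block now gets its per-block scaling groups from the AWQ wrapper instead of the hard-coded OPT/LLaMA/BLOOM/MPT/Falcon branches, and unpacks each group into _auto_get_scale. Judging from the keyword arguments used in the removed branches above, one entry returned by get_layers_for_scaling looks roughly like this (the submodule names below are illustrative MPT-style placeholders, not taken verbatim from this diff):

# illustrative shape of one entry from get_layers_for_scaling();
# keys mirror the keyword arguments passed to _auto_get_scale above,
# with module2inspect and kwargs apparently optional (some branches omit them)
dict(
    prev_op=module.norm_1,          # op whose output feeds the scaled linears (placeholder)
    layers=[module.attn.Wqkv],      # linear layer(s) sharing one scale (placeholder)
    inp=input_feat['attn.Wqkv'],    # cached calibration inputs for those layers
    module2inspect=module.attn,     # optional: larger module re-run during the scale search
    kwargs=module_kwargs,           # optional: forward kwargs such as the attention mask
)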
import torch
import torch.nn as nn
import tqdm
import gc
import functools
from collections import defaultdict
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from .auto_scale import auto_scale_block, apply_scale
from .auto_clip import auto_clip_block, apply_clip
from ..models import MptAWQForCausalLM
__all__ = ["run_awq"]
def get_named_linears(module):
return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
def get_blocks(model):
if isinstance(model, LlamaForCausalLM):
layers = model.model.layers
elif isinstance(model, OPTForCausalLM):
layers = model.model.decoder.layers
elif isinstance(model, BloomForCausalLM):
layers = model.transformer.h
elif "mpt" in str(model.__class__).lower():
layers = MptAWQForCausalLM.get_model_layers(model)
elif "falcon" in str(model.__class__).lower():
layers = model.transformer.h
else:
raise NotImplementedError(type(model))
return layers
def move_embed(model, device):
if isinstance(model, LlamaForCausalLM):
model.model.embed_tokens = model.model.embed_tokens.to(device)
elif isinstance(model, OPTForCausalLM):
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
elif isinstance(model, BloomForCausalLM):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
elif "mpt" in str(model.__class__).lower():
MptAWQForCausalLM.move_embed(model, device)
elif "falcon" in str(model.__class__).lower():
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
else:
raise NotImplementedError(type(model))
@torch.no_grad()
def run_awq(
model, enc,
w_bit, q_config,
n_samples=512, seqlen=512,
auto_scale=True, mse_range=True,
# some configs for ablation study
calib_data="pileval",
):
from ..utils.calib_data import get_calib_dataset
from ..utils.module import append_str_prefix, get_op_name
layers = get_blocks(model)
samples = get_calib_dataset(
data=calib_data, tokenizer=enc, n_samples=n_samples, block_size=seqlen)
samples = torch.cat(samples, dim=0)
inps = []
layer_kwargs = {}
layers[0] = layers[0].cuda()
move_embed(model, "cuda")
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps.append(inp)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference
# patch layer 0 to catch input and kwargs
layers[0] = Catcher(layers[0])
try:
model(samples.to(next(model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
layers[0] = layers[0].module # restore
inps = inps[0]
layers[0] = layers[0].cpu()
move_embed(model, "cpu")
gc.collect()
torch.cuda.empty_cache()
awq_results = {
"scale": [],
"clip": [],
}
# solve layer by layer
for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
layer = layers[i]
layer = layer.cuda()
named_linears = get_named_linears(layer)
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
x = x[0]
x = x.detach().cpu()
feat_dict[name].append(x)
input_feat = defaultdict(list)
handles = []
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
feat_dict=input_feat)))
inps = inps.to(next(layer.parameters()).device) # in case multi-gpu
# get output as next layer's input
inps = layer(inps, **layer_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
# Clear GPU memory
torch.cuda.empty_cache()
if auto_scale: # if it applies, we should also modify the input_feat with scales
scales_list = auto_scale_block(
layer, layer_kwargs,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,
)
# apply_scale(layer, scales_list, input_feat_dict=input_feat)
apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
# Clear GPU memory
torch.cuda.empty_cache()
if mse_range:
clip_list = auto_clip_block(layer,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,)
apply_clip(layer, clip_list)
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
layer = layer.cpu()
# Haotian: check activation replacement
del input_feat
gc.collect()
torch.cuda.empty_cache()
return awq_results
def apply_awq(model, awq_results):
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
@@ -4,7 +4,6 @@ from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from ..models import MptAWQForCausalLM
from transformers.models.bloom.modeling_bloom import BloomBlock
@@ -30,7 +29,7 @@ def scale_activations(module):
return
# get activation scale
scale_dict = MptAWQForCausalLM.get_act_for_scaling(module)
scale_dict = MptAWQForCausalLM().get_act_for_scaling(module)
scale_like = torch.ones(scale_dict['scale_shape'], dtype=dtype, device=device)
# scale activation
...
import torch.nn as nn
def get_named_linears(module):
return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
def get_op_by_name(module, op_name):
# get the op by its name relative to the module
...