Commit 7e091fb1 authored by Casper

Initial refactor [WIP].

parent efea69e1
from .mpt import MptAWQForCausalLM
\ No newline at end of file
class BaseAWQForCausalLM:
    # WIP skeleton: the model-agnostic entry points every *AWQForCausalLM
    # subclass is expected to share once the refactor is complete.
    def quantize(self):
        pass

    def save_quantized(self):
        pass

    def from_pretrained(self):
        pass

    def from_quantized(self):
        pass
\ No newline at end of file
from .base import BaseAWQForCausalLM


class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"

    @staticmethod
    def get_model_layers(model):
        # the transformer blocks are the units AWQ quantizes block by block
        return model.transformer.blocks

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        # each dict maps directly onto _auto_get_scale's keyword arguments
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.norm_1,
            layers=[module.attn.Wqkv],
            inp=input_feat['attn.Wqkv'],
            module2inspect=module.attn,
            kwargs=module_kwargs
        ))

        # attention output
        layers.append(dict(
            prev_op=module.attn.Wqkv,
            layers=[module.attn.out_proj],
            inp=input_feat['attn.out_proj']
        ))

        # linear 1
        layers.append(dict(
            prev_op=module.norm_2,
            layers=[module.ffn.up_proj],
            inp=input_feat['ffn.up_proj'],
            module2inspect=module.ffn
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.ffn.act,
            layers=[module.ffn.down_proj],
            inp=input_feat['ffn.down_proj']
        ))

        return layers

    @staticmethod
    def get_act_for_scaling(module):
        # ffn.act gets wrapped in a ScaledActivation sized to up_proj's output
        return dict(
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features
        )

    @staticmethod
    def move_embed(model, device):
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
\ No newline at end of file
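
The static hooks above are the entire MPT-specific surface the generic pipeline needs. Below is a minimal sketch of how two of them could be driven together, assuming an MPT checkpoint already loaded through transformers; the helper name prepare_mpt_for_calibration and the top-level package name awq are assumptions, not part of this commit:

from awq.models import MptAWQForCausalLM  # assumed package path; the class is defined above

def prepare_mpt_for_calibration(model, device="cuda:0"):
    # hypothetical helper: embeddings must sit on the calibration device
    # before input features can be captured for the first block
    MptAWQForCausalLM.move_embed(model, device)
    # the returned blocks are what the AWQ loop quantizes one at a time
    return MptAWQForCausalLM.get_model_layers(model)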
@@ -8,6 +8,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMS
 from .qmodule import ScaledActivation
 from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
+from ..models import MptAWQForCausalLM
 __all__ = ["auto_scale_block", "apply_scale"]
@@ -265,34 +266,11 @@ def auto_scale_block(module, module_kwargs,
             inp=input_feat['mlp.dense_4h_to_h'],
         ))
     elif "mpt" in str(module.__class__).lower():
-        # attention input
-        scales_list.append(_auto_get_scale(
-            prev_op=module.norm_1,
-            layers=[module.attn.Wqkv],
-            inp=input_feat['attn.Wqkv'],
-            module2inspect=module.attn,
-            kwargs=module_kwargs,
-        ))
-        # attn out
-        scales_list.append(_auto_get_scale(
-            prev_op=module.attn.Wqkv,
-            layers=[module.attn.out_proj],
-            inp=input_feat['attn.out_proj'],
-        ))
-        # fc1
-        scales_list.append(_auto_get_scale(
-            prev_op=module.norm_2,
-            layers=[module.ffn.up_proj],
-            inp=input_feat['ffn.up_proj'],
-            module2inspect=module.ffn,
-        ))
-        # fc2
-        scales_list.append(_auto_get_scale(
-            prev_op=module.ffn.act,
-            layers=[module.ffn.down_proj],
-            inp=input_feat['ffn.down_proj'],
-        ))
+        layers: list[dict] = MptAWQForCausalLM.get_layers_for_scaling(
+            module, input_feat, module_kwargs
+        )
+        # each dict holds _auto_get_scale's keyword arguments, so unpack it
+        layers_scaled = [_auto_get_scale(**layer) for layer in layers]
+        scales_list.extend(layers_scaled)
     elif "falcon" in str(module.__class__).lower():
         # attn out
......
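
In the new branch above, the dicts returned by get_layers_for_scaling are handed straight to _auto_get_scale, which takes keyword arguments, hence the **-unpacking. A self-contained toy of that pattern; the stub below only echoes its inputs and merely stands in for the real _auto_get_scale:

def _auto_get_scale_stub(prev_op=None, layers=None, inp=None,
                         module2inspect=None, kwargs=None):
    # stand-in: report what it was handed instead of searching for scales
    return (prev_op, list(layers), inp is not None)

layer_dicts = [
    dict(prev_op="norm_1", layers=["attn.Wqkv"], inp=[0.1, 0.2]),
    dict(prev_op="attn.Wqkv", layers=["attn.out_proj"], inp=[0.3]),
]
# each dict's keys line up with the stub's keyword arguments
scales_list = [_auto_get_scale_stub(**d) for d in layer_dicts]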
@@ -11,6 +11,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from .auto_scale import auto_scale_block, apply_scale
 from .auto_clip import auto_clip_block, apply_clip
+from ..models import MptAWQForCausalLM
 __all__ = ["run_awq"]
@@ -27,7 +28,7 @@ def get_blocks(model):
     elif isinstance(model, BloomForCausalLM):
         layers = model.transformer.h
     elif "mpt" in str(model.__class__).lower():
-        layers = model.transformer.blocks
+        layers = MptAWQForCausalLM.get_model_layers(model)
     elif "falcon" in str(model.__class__).lower():
         layers = model.transformer.h
     else:
@@ -44,8 +45,7 @@ def move_embed(model, device):
         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
         model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
     elif "mpt" in str(model.__class__).lower():
-        model.transformer.wte = model.transformer.wte.to(device)
-        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
+        MptAWQForCausalLM.move_embed(model, device)
     elif "falcon" in str(model.__class__).lower():
         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
     else:
......
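
get_blocks and move_embed match MPT by substring on the class name rather than with isinstance, presumably because MPT checkpoints load via trust_remote_code and expose no class importable from transformers. A toy illustration of that dispatch; the class below is a stand-in, not the real model class:

class MPTForCausalLM:
    # stand-in for the dynamically loaded MPT model class
    pass

model = MPTForCausalLM()
# str(model.__class__) looks like "<class '__main__.MPTForCausalLM'>"
assert "mpt" in str(model.__class__).lower()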
@@ -4,6 +4,7 @@ from tqdm import tqdm
 import gc
 from .qmodule import ScaledActivation
 from ..utils.module import set_op_by_name
+from ..models import MptAWQForCausalLM
 from transformers.models.bloom.modeling_bloom import BloomBlock
@@ -27,12 +28,15 @@ def scale_activations(module):
     elif 'mptblock' in str(module.__class__.__name__).lower():
         if isinstance(module.ffn.act, ScaledActivation):
             return
-        c = module.ffn.up_proj.out_features
-        act = ScaledActivation(
-            module.ffn.act,
-            torch.ones(c, dtype=dtype, device=device)
-        )
-        set_op_by_name(module, "ffn.act", act)
+        # get activation scale
+        scale_dict = MptAWQForCausalLM.get_act_for_scaling(module)
+        scale_like = torch.ones(scale_dict['scale_shape'], dtype=dtype, device=device)
+        # scale activation
+        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
+        set_op_by_name(module, scale_dict['scale_name'], scaled_act)
     elif 'falcon' in str(module.__class__).lower():
         if isinstance(module.mlp.act, ScaledActivation):
             return
......
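
The refactored branch above still wraps ffn.act in a ScaledActivation sized by get_act_for_scaling's scale_shape. As a rough mental model only (an assumption about ScaledActivation's behaviour, not code from this repo), such a wrapper divides the activation output by learnable per-channel scales so a compensating factor can be folded into the following linear layer:

import torch
import torch.nn as nn

class ScaledActivationSketch(nn.Module):
    # illustrative stand-in; the real ScaledActivation lives in qmodule
    def __init__(self, act, scales):
        super().__init__()
        self.act = act
        self.scales = nn.Parameter(scales)

    def forward(self, x):
        # divide per output channel; shape assumes (batch, seq, features)
        return self.act(x) / self.scales.view(1, 1, -1)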