Commit 984fd2f8 authored by Casper Hansen

Add GPT BigCode support (StarCoder)

parent a5e8b048
@@ -3,4 +3,5 @@ from .llama import LlamaAWQForCausalLM
 from .opt import OptAWQForCausalLM
 from .falcon import FalconAWQForCausalLM
 from .bloom import BloomAWQForCausalLM
-from .gptj import GPTJAWQForCausalLM
\ No newline at end of file
+from .gptj import GPTJAWQForCausalLM
+from .gpt_bigcode import GptBigCodeAWQForCausalLM
\ No newline at end of file
@@ -11,7 +11,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "RefinedWebModel": FalconAWQForCausalLM,
     "falcon": FalconAWQForCausalLM,
     "bloom": BloomAWQForCausalLM,
-    "gptj": GPTJAWQForCausalLM
+    "gptj": GPTJAWQForCausalLM,
+    "gpt_bigcode": GptBigCodeAWQForCausalLM
 }
 
 def check_and_get_model_type(model_dir, trust_remote_code=True):
...
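For context, dispatch is keyed by the model_type field of the Hugging Face config, so StarCoder checkpoints (model_type == "gpt_bigcode") now resolve to the new class. Below is a minimal sketch of that lookup; AutoConfig and the "bigcode/starcoderbase" checkpoint name are only assumptions for illustration, not the exact helper in this file.

# Illustrative only: shows how model_type-based dispatch picks up the new map entry.
from transformers import AutoConfig

def resolve_awq_class(model_dir, trust_remote_code=True):
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return AWQ_CAUSAL_LM_MODEL_MAP[config.model_type]

# StarCoder models report model_type == "gpt_bigcode" in config.json,
# so this resolves to GptBigCodeAWQForCausalLM.
awq_class = resolve_awq_class("bigcode/starcoderbase")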
from .base import BaseAWQForCausalLM
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM, GPTBigCodeBlock

class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "GPTBigCodeBlock"
    max_new_tokens_key = "n_positions"

    @staticmethod
    def get_model_layers(model: GPTBigCodeForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: GPTBigCodeBlock):
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.c_fc.out_features
        )

    @staticmethod
    def move_embed(model: GPTBigCodeForCausalLM, device):
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.drop = model.transformer.drop.to(device)

    @staticmethod
    def get_layers_for_scaling(module: GPTBigCodeBlock, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.ln_1,
            layers=[module.attn.c_attn],
            inp=input_feat['attn.c_attn'],
            module2inspect=module.attn,
            kwargs=module_kwargs
        ))

        # attention output
        # layers.append(dict(
        #     prev_op=module.attn.c_attn,
        #     layers=[module.attn.c_proj],
        #     inp=input_feat['attn.c_proj']
        # ))

        # linear 1
        layers.append(dict(
            prev_op=module.ln_2,
            layers=[module.mlp.c_fc],
            inp=input_feat['mlp.c_fc'],
            module2inspect=module.mlp
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.act,
            layers=[module.mlp.c_proj],
            inp=input_feat['mlp.c_proj']
        ))

        return layers
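For reference, a rough end-to-end sketch of quantizing a GPT BigCode checkpoint through the new class. The AutoAWQForCausalLM entry point, the quantize() signature, and the quant_config keys are assumptions about the surrounding project and may differ at this commit.

# Hedged sketch: entry-point name, quantize() signature and quant_config keys are
# assumptions about the surrounding project, not part of this commit.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "bigcode/starcoderbase"  # any checkpoint with model_type == "gpt_bigcode"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

model = AutoAWQForCausalLM.from_pretrained(model_path)   # dispatches to GptBigCodeAWQForCausalLM
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized("starcoder-awq")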
@@ -6,12 +6,14 @@ import logging
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
-from transformers.activations import NewGELUActivation
+from transformers.activations import NewGELUActivation, PytorchGELUTanh
 
 from awq.modules.act import ScaledActivation
 from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
 
 __all__ = ["auto_scale_block", "apply_scale"]
 
+norms = [nn.LayerNorm, LlamaRMSNorm]
+act_functions = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh]
 @torch.no_grad()
 def get_weight_scale(weight, q_group_size=-1):
@@ -80,7 +82,7 @@ def scale_fc_fc(fc1, fc2, scales):
 @torch.no_grad()
 def scale_gelu_fc(gelu, fc, scales):
-    assert any(isinstance(gelu,t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
+    assert any(isinstance(gelu,t) for t in act_functions)
     assert isinstance(fc, nn.Linear)
 
     fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
@@ -194,11 +196,11 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             assert len(layers) == 1
             scale_fc_fc(prev_op, layers[0], scales)
-        elif any(isinstance(prev_op,t) for t in [nn.LayerNorm, LlamaRMSNorm]) \
+        elif any(isinstance(prev_op,t) for t in norms) \
                 or 'rmsnorm' in str(prev_op.__class__).lower():
             scale_ln_fcs(prev_op, layers, scales)
-        elif any(isinstance(prev_op,t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
+        elif any(isinstance(prev_op,t) for t in act_functions):
             new_module = ScaledActivation(prev_op, scales)
             set_op_by_name(module, prev_op_name, new_module)
             scale_gelu_fc(prev_op, layers[0], scales)
...
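The act_functions branch pairs two operations: ScaledActivation divides the activation output by per-channel scales, and scale_gelu_fc multiplies the following linear layer's weight columns by the same scales, so the block output stays numerically the same while the activations feeding mlp.c_proj shrink into a range that quantizes better. A self-contained sketch of that equivalence, using a stand-in class with the assumed divide-by-scales behaviour (the real module lives in awq.modules.act):

# Sketch only: ScaledActivationSketch mimics the assumed behaviour of ScaledActivation.
import torch
import torch.nn as nn

class ScaledActivationSketch(nn.Module):
    def __init__(self, act, scales):
        super().__init__()
        self.act = act
        self.scales = nn.Parameter(scales)

    def forward(self, x):
        # divide the activation output channel-wise by the scales
        return self.act(x) / self.scales.view(1, -1).to(x.device)

torch.manual_seed(0)
hidden, inter = 8, 32
act, fc = nn.GELU(), nn.Linear(inter, hidden, bias=False)
scales = torch.rand(inter) + 0.5
x = torch.randn(2, inter)

ref = fc(act(x))                          # original mlp.act -> mlp.c_proj path
scaled_act = ScaledActivationSketch(act, scales)
fc.weight.data.mul_(scales.view(1, -1))   # what scale_gelu_fc does to mlp.c_proj
out = fc(scaled_act(x))                   # smaller activations, weights absorb the scales
assert torch.allclose(ref, out, atol=1e-5)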