Unverified Commit 7fbe9bbc authored by Casper's avatar Casper Committed by GitHub
Browse files

Merge pull request #1 from jamesdborin/new_model/gptj

Add GPTJ Support
parents 3a8072a1 50d1025f
......@@ -2,4 +2,5 @@ from .mpt import MptAWQForCausalLM
from .llama import LlamaAWQForCausalLM
from .opt import OptAWQForCausalLM
from .falcon import FalconAWQForCausalLM
from .bloom import BloomAWQForCausalLM
\ No newline at end of file
from .bloom import BloomAWQForCausalLM
from .gptj import GPTJAWQForCausalLM
\ No newline at end of file
......@@ -8,7 +8,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"opt": OptAWQForCausalLM,
"RefinedWeb": FalconAWQForCausalLM,
"RefinedWebModel": FalconAWQForCausalLM,
"bloom": BloomAWQForCausalLM
"bloom": BloomAWQForCausalLM,
"gptj": GPTJAWQForCausalLM
}
def check_and_get_model_type(model_dir, trust_remote_code=True):
......
......@@ -113,8 +113,8 @@ class BaseAWQForCausalLM(nn.Module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps.append(inp)
# Hijack the first transformer layer's forward pass: capture the inputs
# and keyword arguments the model would feed it, then abort inference —
# the caller only needs the captured calibration data, not the outputs.
def forward(self, hijacked_inputs, **kwargs):
# NOTE(review): `inps` and `layer_kwargs` are closure variables owned by
# the enclosing calibration routine (not visible in this hunk) — confirm
# against the full file.
inps.append(hijacked_inputs)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference
......@@ -358,4 +358,4 @@ class BaseAWQForCausalLM(nn.Module):
# scale activation
scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
\ No newline at end of file
set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
from .base import BaseAWQForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock
class GPTJAWQForCausalLM(BaseAWQForCausalLM):
    """AWQ quantization adapter for GPT-J models.

    Supplies the model-specific hooks (layer access, activation scaling,
    embedding placement, per-block scaling groups) that the base AWQ
    workflow drives.
    """

    # Class name of the transformer block this adapter operates on.
    layer_type = "GPTJBlock"
    # Config key that holds GPT-J's maximum sequence length.
    max_new_tokens_key = "n_positions"

    @staticmethod
    def get_model_layers(model: GPTJForCausalLM):
        """Return the stack of GPT-J transformer blocks."""
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: GPTJBlock):
        """Describe the MLP activation that AWQ should wrap for scaling."""
        return {
            "is_scalable": True,
            "scale_name": "mlp.act",
            "scale_layer": module.mlp.act,
            "scale_shape": module.mlp.fc_in.out_features,
        }

    @staticmethod
    def move_embed(model: GPTJForCausalLM, device: str):
        """Move the token embedding table onto *device*."""
        model.transformer.wte = model.transformer.wte.to(device)

    @staticmethod
    def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs):
        """Build the AWQ scaling groups for one GPT-J block.

        Each group pairs the op that precedes a set of linear layers with
        those layers and their captured input activations.
        """
        # ln_1's output feeds q/k/v *and* fc_in here — GPT-J's layout runs
        # attention and the MLP off the same normalized input.
        attn_and_fc_in = {
            "prev_op": module.ln_1,
            "layers": [
                module.attn.q_proj,
                module.attn.k_proj,
                module.attn.v_proj,
                module.mlp.fc_in,
            ],
            "inp": input_feat['attn.q_proj'],
            "module2inspect": module,
            "kwargs": module_kwargs,
        }
        # Attention output projection, scaled relative to the value projection.
        attn_out = {
            "prev_op": module.attn.v_proj,
            "layers": [module.attn.out_proj],
            "inp": input_feat['attn.out_proj'],
        }
        # Second MLP projection, scaled relative to the activation function.
        mlp_out = {
            "prev_op": module.mlp.act,
            "layers": [module.mlp.fc_out],
            "inp": input_feat['mlp.fc_out'],
        }
        return [attn_and_fc_in, attn_out, mlp_out]
\ No newline at end of file
......@@ -5,7 +5,7 @@ import torch.nn as nn
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import NewGELUActivation
from .qmodule import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
......@@ -79,7 +79,7 @@ def scale_fc_fc(fc1, fc2, scales):
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
assert any(isinstance(gelu,t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
assert isinstance(fc, nn.Linear)
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
......@@ -195,7 +195,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
scale_fc_fc(prev_op, layers[0], scales)
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
scale_ln_fcs(prev_op, layers, scales)
elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
elif any(isinstance(prev_op,t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment