Unverified commit 7fbe9bbc authored by Casper, committed by GitHub

Merge pull request #1 from jamesdborin/new_model/gptj

Add GPTJ Support
parents 3a8072a1 50d1025f
@@ -2,4 +2,5 @@ from .mpt import MptAWQForCausalLM
 from .llama import LlamaAWQForCausalLM
 from .opt import OptAWQForCausalLM
 from .falcon import FalconAWQForCausalLM
-from .bloom import BloomAWQForCausalLM
\ No newline at end of file
+from .bloom import BloomAWQForCausalLM
+from .gptj import GPTJAWQForCausalLM
\ No newline at end of file
@@ -8,7 +8,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "opt": OptAWQForCausalLM,
     "RefinedWeb": FalconAWQForCausalLM,
     "RefinedWebModel": FalconAWQForCausalLM,
-    "bloom": BloomAWQForCausalLM
+    "bloom": BloomAWQForCausalLM,
+    "gptj": GPTJAWQForCausalLM
 }

 def check_and_get_model_type(model_dir, trust_remote_code=True):
...
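For context, the keys of AWQ_CAUSAL_LM_MODEL_MAP match Hugging Face `config.model_type` values, and GPT-J checkpoints report "gptj". The body of check_and_get_model_type is elided in this diff; the following is a hedged reconstruction of the lookup, assumed only from the function's signature and the map above:

```python
from transformers import AutoConfig

# Assumed reconstruction: the real function body is not shown in this diff.
def check_and_get_model_type(model_dir, trust_remote_code=True):
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type
```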
@@ -113,8 +113,8 @@ class BaseAWQForCausalLM(nn.Module):
         super().__init__()
         self.module = module

-    def forward(self, inp, **kwargs):
-        inps.append(inp)
+    def forward(self, hijacked_inputs, **kwargs):
+        inps.append(hijacked_inputs)
         layer_kwargs.update(kwargs)
         raise ValueError  # early exit to break later inference
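This forward belongs to the input-catching wrapper used during calibration: it wraps the first decoder block, records the hidden states and kwargs that reach it, then raises ValueError so the rest of the forward pass is skipped. The rename from `inp` to `hijacked_inputs` presumably avoids a name clash with model-specific arguments. A self-contained sketch of the control flow, with a toy block stack standing in for the real model:

```python
import torch
import torch.nn as nn

inps, layer_kwargs = [], {}

class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module  # the wrapped first block

    def forward(self, hijacked_inputs, **kwargs):
        inps.append(hijacked_inputs)  # hidden states entering the first block
        layer_kwargs.update(kwargs)   # attention mask, position ids, etc.
        raise ValueError              # early exit: skip all remaining layers

# Toy stand-in for a decoder stack (e.g. model.transformer.h), for illustration.
blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

def run(x):
    for block in blocks:
        x = block(x)
    return x

blocks[0] = Catcher(blocks[0])
try:
    run(torch.randn(1, 4))        # only the first (wrapped) block executes
except ValueError:
    pass
blocks[0] = blocks[0].module      # restore the original block

print(len(inps), inps[0].shape)   # one captured calibration input
```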
@@ -358,4 +358,4 @@ class BaseAWQForCausalLM(nn.Module):
         # scale activation
         scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
-        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
\ No newline at end of file
+        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
from .base import BaseAWQForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock


class GPTJAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "GPTJBlock"
    max_new_tokens_key = "n_positions"

    @staticmethod
    def get_model_layers(model: GPTJForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: GPTJBlock):
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.fc_in.out_features
        )

    @staticmethod
    def move_embed(model: GPTJForCausalLM, device: str):
        model.transformer.wte = model.transformer.wte.to(device)

    @staticmethod
    def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs):
        layers = []

        # attention input + linear 1
        layers.append(dict(
            prev_op=module.ln_1,
            layers=[module.attn.q_proj,
                    module.attn.k_proj, module.attn.v_proj, module.mlp.fc_in],
            inp=input_feat['attn.q_proj'],
            module2inspect=module,
            kwargs=module_kwargs,
        ))

        # attention out
        layers.append(dict(
            prev_op=module.attn.v_proj,
            layers=[module.attn.out_proj],
            inp=input_feat['attn.out_proj'],
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.act,
            layers=[module.mlp.fc_out],
            inp=input_feat['mlp.fc_out'],
        ))

        return layers
\ No newline at end of file
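A note on the first scaling group above: a GPT-J block runs attention and the MLP in parallel off a single ln_1, so q_proj, k_proj, v_proj, and mlp.fc_in all consume the same normalized input and can share one scaling operation. Below is a usage sketch for the new model type; the AutoAWQForCausalLM entry point, quantize(), and save_quantized() are assumed from the surrounding project, and only the "gptj" dispatch comes from this PR:

```python
# Usage sketch: entry-point and method names outside this diff are assumptions.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "EleutherAI/gpt-j-6b"   # example checkpoint with model_type "gptj"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(tokenizer, quant_config=quant_config)  # dispatches to GPTJAWQForCausalLM
model.save_quantized("gpt-j-6b-awq")
```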
@@ -5,7 +5,7 @@ import torch.nn as nn
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
+from transformers.activations import NewGELUActivation

 from .qmodule import ScaledActivation
 from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
@@ -79,7 +79,7 @@ def scale_fc_fc(fc1, fc2, scales):
 @torch.no_grad()
 def scale_gelu_fc(gelu, fc, scales):
-    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
+    assert any(isinstance(gelu, t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
     assert isinstance(fc, nn.Linear)

     fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
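scale_gelu_fc pairs with ScaledActivation: the activation's output is divided by the scales while the following linear layer's weight columns are multiplied by the same scales, so the composition is numerically unchanged. A self-contained equivalence check, using a minimal stand-in for ScaledActivation (the real class lives in .qmodule):

```python
import torch
import torch.nn as nn

# Minimal stand-in for awq's ScaledActivation, used only for this check.
class ScaledActivation(nn.Module):
    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, -1).to(x.device)

torch.manual_seed(0)
act = nn.GELU()
fc = nn.Linear(8, 4, bias=False)
scales = torch.rand(8) + 0.5        # strictly positive scales

x = torch.randn(2, 8)
ref = fc(act(x))                    # original computation

# Same transform as scale_gelu_fc: multiply fc's input columns by the scales...
fc.weight.data.mul_(scales.view(1, -1))
# ...while ScaledActivation divides the activation output by the same scales.
out = fc(ScaledActivation(act, scales)(x))

assert torch.allclose(ref, out, atol=1e-6)  # output is numerically unchanged
```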
@@ -195,7 +195,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             scale_fc_fc(prev_op, layers[0], scales)
         elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
             scale_ln_fcs(prev_op, layers, scales)
-        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
+        elif any(isinstance(prev_op, t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
             new_module = ScaledActivation(prev_op, scales)
             set_op_by_name(module, prev_op_name, new_module)
             scale_gelu_fc(prev_op, layers[0], scales)
...
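The motivation for widening these isinstance checks: GPT-J's MLP activation is transformers' NewGELUActivation (the tanh-approximate "gelu_new"), which is not a subclass of torch's nn.GELU, so the old asserts would have rejected it. A quick verification sketch (fetches the GPT-J config from the Hub; the model id is an example checkpoint):

```python
import torch
from transformers import AutoConfig
from transformers.activations import NewGELUActivation
from transformers.models.gptj.modeling_gptj import GPTJBlock

config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6b")
block = GPTJBlock(config)

assert config.activation_function == "gelu_new"
assert isinstance(block.mlp.act, NewGELUActivation)
assert not isinstance(block.mlp.act, torch.nn.GELU)  # why the old check failed
```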