Commit 5c2c85fe authored by Casper Hansen

Support Bloom models

parent dd3010fb
@@ -2,3 +2,4 @@ from .mpt import MptAWQForCausalLM
 from .llama import LlamaAWQForCausalLM
 from .opt import OptAWQForCausalLM
 from .falcon import FalconAWQForCausalLM
+from .bloom import BloomAWQForCausalLM
\ No newline at end of file
@@ -7,7 +7,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "llama": LlamaAWQForCausalLM,
     "opt": OptAWQForCausalLM,
     "RefinedWeb": FalconAWQForCausalLM,
-    "RefinedWebModel": FalconAWQForCausalLM
+    "RefinedWebModel": FalconAWQForCausalLM,
+    "bloom": BloomAWQForCausalLM
 }
 
 def check_and_get_model_type(model_dir, trust_remote_code=True):
...
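For context, here is a minimal, hypothetical sketch of how the new mapping gets exercised end to end. It assumes the AutoAWQ top-level AutoAWQForCausalLM entry point and a quant_config dict as used elsewhere in the repository around this time; the checkpoint name and quantization settings are illustrative, not part of this commit:

from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM

# Any checkpoint whose config reports model_type == "bloom" now resolves to
# BloomAWQForCausalLM through check_and_get_model_type and the map above.
model_path = "bigscience/bloom-560m"  # illustrative BLOOM checkpoint

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Illustrative settings; see the repository README for the exact quant_config schema.
model.quantize(tokenizer, quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4})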
from .base import BaseAWQForCausalLM
from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomBlock


class BloomAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "BloomBlock"

    @staticmethod
    def get_model_layers(model: BloomForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: BloomBlock):
        return dict(
            is_scalable=True,
            scale_name="mlp.gelu_impl",
            scale_layer=module.mlp.gelu_impl,
            scale_shape=module.mlp.dense_h_to_4h.out_features
        )

    @staticmethod
    def move_embed(model: BloomForCausalLM, device: str):
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
        model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)

    @staticmethod
    def get_layers_for_scaling(module: BloomBlock, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attention.query_key_value],
            inp=input_feat['self_attention.query_key_value'],
            module2inspect=module, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
        """
        scales_list.append(_auto_get_scale(
            prev_op=module.self_attention.query_key_value,
            layers=[module.self_attention.dense],
            inp=input_feat['self_attention.dense'],
        ))
        """

        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.dense_h_to_4h],
            inp=input_feat['mlp.dense_h_to_4h'],
            module2inspect=module, kwargs=module_kwargs,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.gelu_impl,
            layers=[module.mlp.dense_4h_to_h],
            inp=input_feat['mlp.dense_4h_to_h'],
        ))

        return layers
\ No newline at end of file
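To make the intent of these descriptors concrete, below is a simplified, hypothetical walk over the groups returned by get_layers_for_scaling. The actual scale search lives elsewhere in the repository and does considerably more (grid-searching scales, measuring quantization error); this only shows what each dict carries:

def inspect_scaling_groups(block, input_feat, module_kwargs):
    # Hypothetical helper for illustration only.
    for group in BloomAWQForCausalLM.get_layers_for_scaling(block, input_feat, module_kwargs):
        prev_op = group["prev_op"]   # op whose output feeds the group (layernorm or activation)
        linears = group["layers"]    # linear layer(s) sharing that input
        calib_inp = group["inp"]     # cached calibration activations for those layers
        # A real implementation would search per-channel scales here, dividing them
        # into prev_op's output and multiplying them into each linear's weights.
        print(type(prev_op).__name__, [type(l).__name__ for l in linears], tuple(calib_inp.shape))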