Unverified commit 9e8e28b2, authored by Younes Belkada and committed by GitHub

FEAT: add llava to autoawq (#250)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent 727172e9
@@ -78,6 +78,7 @@ The detailed support list:
| BigCode | 1B/7B/15B |
| GPT NeoX | 20B |
| GPT-J | 6B |
| Llava | 7B/13B |
## Usage
......
@@ -10,4 +10,5 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
from .aquila import AquilaAWQForCausalLM
from .yi import YiAWQForCausalLM
from .qwen import QwenAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
\ No newline at end of file
from .llava import LlavaAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
@@ -18,7 +18,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"gpt_neox": GPTNeoXAWQForCausalLM,
"aquila": AquilaAWQForCausalLM,
"Yi": YiAWQForCausalLM,
"qwen": QwenAWQForCausalLM
"qwen": QwenAWQForCausalLM,
"llava": LlavaAWQForCausalLM,
}
def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
......
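For context, the collapsed body of `check_and_get_model_type` only reads the checkpoint config and uses its `model_type` to index `AWQ_CAUSAL_LM_MODEL_MAP`. A minimal sketch of that lookup (the exact error handling here is an assumption):

```python
from transformers import AutoConfig

def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
    # Only the config is needed to discover the architecture family, e.g. "llava".
    config = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs
    )
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type
```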
import os
import gc
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
import transformers
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import (
get_named_linears,
@@ -18,19 +20,46 @@ from awq.utils.module import (
exclude_layers_to_not_quantize,
)
from transformers import (
AutoModelForCausalLM,
AutoConfig,
PreTrainedModel,
PretrainedConfig,
AutoProcessor,
CLIPImageProcessor,
)
from accelerate.big_modeling import (
init_empty_weights,
infer_auto_device_map,
load_checkpoint_and_dispatch,
)
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.module import get_named_linears, set_op_by_name
# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
"mpt": "AutoModelForCausalLM",
"llama": "AutoModelForCausalLM",
"opt": "AutoModelForCausalLM",
"RefinedWeb": "AutoModelForCausalLM",
"RefinedWebModel": "AutoModelForCausalLM",
"falcon": "AutoModelForCausalLM",
"bloom": "AutoModelForCausalLM",
"gptj": "AutoModelForCausalLM",
"gpt_bigcode": "AutoModelForCausalLM",
"mistral": "AutoModelForCausalLM",
"mixtral": "AutoModelForCausalLM",
"gpt_neox": "AutoModelForCausalLM",
"aquila": "AutoModelForCausalLM",
"Yi": "AutoModelForCausalLM",
"qwen": "AutoModelForCausalLM",
"llava": "AutoModelForVision2Seq",
}
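As the `from_pretrained` and `from_quantized` hunks below show, this dict is only used to resolve the right `transformers` auto class by name via `getattr`. A minimal illustration for the new `llava` entry (the model id and dtype are just placeholders):

```python
import transformers

target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT["llava"]  # "AutoModelForVision2Seq"
target_cls = getattr(transformers, target_cls_name)

# For a Llava checkpoint this dispatches to LlavaForConditionalGeneration,
# assuming a transformers version that ships native Llava support.
model = target_cls.from_pretrained("llava-hf/llava-1.5-7b-hf", torch_dtype="auto")
```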
class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, config, quant_config):
def __init__(self, model, model_type, is_quantized, config, quant_config, processor):
super().__init__()
self.model:PreTrainedModel = model
self.model_type:str = model_type
@@ -38,6 +67,7 @@ class BaseAWQForCausalLM(nn.Module):
self.search_result = None
self.config: PretrainedConfig = config
self.quant_config: AwqConfig = quant_config
self.processor: CLIPImageProcessor = processor
def to(self, device: str):
return self.model.to(device)
@@ -79,6 +109,10 @@ class BaseAWQForCausalLM(nn.Module):
self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
self.quant_config.save_pretrained(save_dir)
# Vision transformers have a processor
if self.processor is not None:
self.processor.save_pretrained(save_dir)
# Remove empty state dict
default_paths = [f'{save_dir}/model.safetensors', f'{save_dir}/pytorch_model.bin']
for path in default_paths:
@@ -118,8 +152,16 @@ class BaseAWQForCausalLM(nn.Module):
self, model_path, '', safetensors, trust_remote_code=trust_remote_code
)
target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
target_cls = getattr(transformers, target_cls_name)
processor = None
if target_cls_name == "AutoModelForVision2Seq":
processor = AutoProcessor.from_pretrained(model_weights_path)
processor: CLIPImageProcessor = processor.image_processor
# If not quantized, must load with the matching transformers auto class (target_cls)
model = AutoModelForCausalLM.from_pretrained(
model = target_cls.from_pretrained(
model_weights_path,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
@@ -130,7 +172,8 @@ class BaseAWQForCausalLM(nn.Module):
model.eval()
return self(model, model_type, is_quantized=False, config=config, quant_config=quant_config)
return self(model, model_type, is_quantized=False, config=config,
quant_config=quant_config, processor=processor)
@classmethod
def from_quantized(self, model_path, model_type, model_filename='',
@@ -145,10 +188,13 @@ class BaseAWQForCausalLM(nn.Module):
trust_remote_code, max_new_tokens=max_new_tokens,
**config_kwargs
)
target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
target_cls = getattr(transformers, target_cls_name)
# [STEP 3] Load model
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
model = target_cls.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, quant_config, quant_config.version)
@@ -170,7 +216,8 @@ class BaseAWQForCausalLM(nn.Module):
if fuse_layers:
self.fuse_layers(model)
return self(model, model_type, is_quantized=is_quantized, config=config, quant_config=quant_config)
return self(model, model_type, is_quantized=is_quantized, config=config,
quant_config=quant_config, processor=None)
def _load_config(self, model_path, model_filename, safetensors=True,
version="GEMM", trust_remote_code=True, max_new_tokens=4096,
@@ -197,7 +244,10 @@ class BaseAWQForCausalLM(nn.Module):
# Load model config and set max generation length
if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
config.max_new_tokens = getattr(config, self.max_new_tokens_key)
config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
# Also propagate the generation limit to the nested text config so multi-modal models can generate
if hasattr(config, "text_config"):
config.text_config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
else:
max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
......
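The `text_config` branch above matters because Llava's top-level `LlavaConfig` nests the language model's config, and the fuser in the next file reads `max_new_tokens` from `model.language_model.config`, i.e. from that nested config. A rough illustration of the nesting, assuming the attribute names used by `transformers`' Llava configs:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(type(config).__name__)               # LlavaConfig
print(type(config.text_config).__name__)   # LlamaConfig for the 7B/13B checkpoints
# The context length lives on the nested text config, which the wrapped
# language model also exposes as model.language_model.config.
print(config.text_config.max_position_embeddings)
```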
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
)
from transformers.models.llava.modeling_llava import (
    LlavaForConditionalGeneration as OldLlavaForConditionalGeneration,
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class LlavaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "LlamaDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldLlavaForConditionalGeneration):
fuser = LlavaFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldLlavaForConditionalGeneration):
return model.language_model.model.layers
@staticmethod
def get_act_for_scaling(module: OldLlamaDecoderLayer):
return dict(
is_scalable=False
)
@staticmethod
def move_embed(model: OldLlavaForConditionalGeneration, device: str):
model.language_model.model.embed_tokens = model.get_input_embeddings().to(device)
@staticmethod
def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
# linear 1
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'],
module2inspect=module.mlp,
))
# linear 2
layers.append(dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
return layers
class LlavaFuser:
def __init__(self, model: OldLlavaForConditionalGeneration):
self.model = model.language_model
self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules()
if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldLlamaDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
)
mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon
)
blocks.append(LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens
))
self.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
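`fuse_layers` is not meant to be called by hand; the base class triggers it at the end of `from_quantized` (the `if fuse_layers:` hunk above), replacing each `LlamaDecoderLayer` of the language model with a fused `LlamaLikeBlock`. A typical call, assuming the keyword argument keeps its usual AutoAWQ name and default:

```python
from awq import AutoAWQForCausalLM

# fuse_layers routes through LlavaAWQForCausalLM.fuse_layers -> LlavaFuser.fuse_transformer
model = AutoAWQForCausalLM.from_quantized(
    "ybelkada/llava-1.5-7b-hf-awq",
    safetensors=True,
    fuse_layers=True,
    device_map={"": 0},
)
```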
import requests
import torch
from PIL import Image
from awq import AutoAWQForCausalLM
from transformers import AutoProcessor
quant_path = "ybelkada/llava-1.5-7b-hf-awq"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, safetensors=True, device_map={"": 0})
processor = AutoProcessor.from_pretrained(quant_path)
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
# Generate output
generation_output = model.generate(
**inputs,
max_new_tokens=512
)
print(processor.decode(generation_output[0], skip_special_tokens=True))
\ No newline at end of file
import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = "llava-hf/llava-1.5-7b-hf"
quant_path = "llava-1.5-7b-hf-awq"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version":"GEMM"}
# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, safetensors=True, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
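Note that because the Llava model is loaded through `from_pretrained` with an `AutoProcessor` attached (see the base-class changes above), `save_quantized` also writes the image processor next to the weights; the generation example above relies on this when it calls `AutoProcessor.from_pretrained(quant_path)`. A quick check, assuming the image processor lands in the usual `preprocessor_config.json`:

```python
import os

# Expected to print True once save_quantized has run for the Llava checkpoint.
print(os.path.isfile(os.path.join("llava-1.5-7b-hf-awq", "preprocessor_config.json")))
```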