Unverified commit 9e8e28b2, authored by Younes Belkada, committed by GitHub

FEAT: add llava to autoawq (#250)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent 727172e9
@@ -78,6 +78,7 @@ The detailed support list:

| BigCode | 1B/7B/15B |
| GPT NeoX | 20B |
| GPT-J | 6B |
| Llava | 7B/13B |

## Usage

...
@@ -10,4 +10,5 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
from .aquila import AquilaAWQForCausalLM
from .yi import YiAWQForCausalLM
from .qwen import QwenAWQForCausalLM
from .llava import LlavaAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
@@ -18,7 +18,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
    "gpt_neox": GPTNeoXAWQForCausalLM,
    "aquila": AquilaAWQForCausalLM,
    "Yi": YiAWQForCausalLM,
    "qwen": QwenAWQForCausalLM,
    "llava": LlavaAWQForCausalLM,
}

def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):

...
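For orientation, this registry is what the `AutoAWQForCausalLM` entry point dispatches on: `check_and_get_model_type` presumably reads `model_type` from the checkpoint's config and rejects architectures that are not in `AWQ_CAUSAL_LM_MODEL_MAP`. A minimal sketch of that assumed flow (not the library's exact code; the function name is a hypothetical stand-in):

# Assumed dispatch flow behind AutoAWQForCausalLM (sketch only).
from transformers import AutoConfig

def check_and_get_model_type_sketch(model_dir, trust_remote_code=True, **model_init_kwargs):
    config = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs
    )
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type

# With this PR, a LLaVA checkpoint resolves to "llava" and thus to LlavaAWQForCausalLM.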
import os
import gc
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
import transformers
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import (
    get_named_linears,

@@ -18,19 +20,46 @@ from awq.utils.module import (

    exclude_layers_to_not_quantize,
)
from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedModel,
    PretrainedConfig,
    AutoProcessor,
    CLIPImageProcessor,
)
from accelerate.big_modeling import (
    init_empty_weights,
    infer_auto_device_map,
    load_checkpoint_and_dispatch,
)
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.module import get_named_linears, set_op_by_name
# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
    "mpt": "AutoModelForCausalLM",
    "llama": "AutoModelForCausalLM",
    "opt": "AutoModelForCausalLM",
    "RefinedWeb": "AutoModelForCausalLM",
    "RefinedWebModel": "AutoModelForCausalLM",
    "falcon": "AutoModelForCausalLM",
    "bloom": "AutoModelForCausalLM",
    "gptj": "AutoModelForCausalLM",
    "gpt_bigcode": "AutoModelForCausalLM",
    "mistral": "AutoModelForCausalLM",
    "mixtral": "AutoModelForCausalLM",
    "gpt_neox": "AutoModelForCausalLM",
    "aquila": "AutoModelForCausalLM",
    "Yi": "AutoModelForCausalLM",
    "qwen": "AutoModelForCausalLM",
    "llava": "AutoModelForVision2Seq",
}
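This dict is consumed by the loaders below: the string is looked up on the `transformers` module with `getattr`, so `"llava"` goes through `AutoModelForVision2Seq` (which resolves to `LlavaForConditionalGeneration` for this checkpoint), while every text-only architecture keeps using `AutoModelForCausalLM`. A short sketch of that resolution, with the model id borrowed from the examples at the end:

# Sketch of how the mapping above resolves to a concrete class.
import transformers
from transformers import AutoConfig

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")       # config.model_type == "llava"
target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]   # "AutoModelForVision2Seq"
target_cls = getattr(transformers, target_cls_name)
model = target_cls.from_pretrained("llava-hf/llava-1.5-7b-hf")        # a LlavaForConditionalGeneration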
class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, config, quant_config, processor):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type

@@ -38,6 +67,7 @@ class BaseAWQForCausalLM(nn.Module):

        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config
        self.processor: CLIPImageProcessor = processor

    def to(self, device: str):
        return self.model.to(device)

@@ -79,6 +109,10 @@ class BaseAWQForCausalLM(nn.Module):

        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Vision transformers have a processor
        if self.processor is not None:
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [f'{save_dir}/model.safetensors', f'{save_dir}/pytorch_model.bin']
        for path in default_paths:
@@ -118,8 +152,16 @@ class BaseAWQForCausalLM(nn.Module):

            self, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        processor = None
        if target_cls_name == "AutoModelForVision2Seq":
            processor = AutoProcessor.from_pretrained(model_weights_path)
            processor: CLIPImageProcessor = processor.image_processor

        # If not quantized, must load with AutoModelForCausalLM
        model = target_cls.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
@@ -130,7 +172,8 @@ class BaseAWQForCausalLM(nn.Module):

        model.eval()

        return self(model, model_type, is_quantized=False, config=config,
                    quant_config=quant_config, processor=processor)
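The practical effect is that a vision checkpoint loaded through AutoAWQ carries its image processor with it, so `save_quantized` can write it next to the quantized weights. A hedged usage sketch (model id taken from the examples below; the attribute layout follows the type annotations above):

# Usage sketch, not part of the diff.
import torch
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", safetensors=True, torch_dtype=torch.float16
)
print(type(model.processor).__name__)   # expected to be CLIPImageProcessor per the annotation above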
    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='',

@@ -145,10 +188,13 @@ class BaseAWQForCausalLM(nn.Module):

            trust_remote_code, max_new_tokens=max_new_tokens,
            **config_kwargs
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
        target_cls = getattr(transformers, target_cls_name)

        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(self, model, quant_config, quant_config.version)
@@ -170,7 +216,8 @@ class BaseAWQForCausalLM(nn.Module):

        if fuse_layers:
            self.fuse_layers(model)

        return self(model, model_type, is_quantized=is_quantized, config=config,
                    quant_config=quant_config, processor=None)
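The quantized path is the usual Accelerate big-model recipe, now parameterized by `target_cls`: build an empty skeleton, swap `nn.Linear` for the AWQ kernels, then dispatch the checkpoint onto devices. A generic sketch of that pattern under the same assumptions (the local path is hypothetical and the Linear-swapping step is elided):

# Generic sketch of the load path above, not the library's exact code.
import torch
import transformers
from transformers import AutoConfig
from accelerate.big_modeling import init_empty_weights, load_checkpoint_and_dispatch

quant_dir = "llava-1.5-7b-hf-awq"   # hypothetical folder containing quantized shards
config = AutoConfig.from_pretrained(quant_dir)
target_cls = getattr(transformers, TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type])

with init_empty_weights():
    # Meta-device skeleton: no real memory is allocated yet.
    model = target_cls.from_config(config=config, torch_dtype=torch.float16)

# AutoAWQ would replace eligible nn.Linear modules with WQLinear_GEMM/WQLinear_GEMV
# here so that the quantized state dict matches the module shapes (elided in this sketch).
model = load_checkpoint_and_dispatch(model, checkpoint=quant_dir, device_map="auto")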
    def _load_config(self, model_path, model_filename, safetensors=True,
                     version="GEMM", trust_remote_code=True, max_new_tokens=4096,

@@ -197,7 +244,10 @@ class BaseAWQForCausalLM(nn.Module):

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
            # To add the generate support for Multi-modal models as well
            if hasattr(config, "text_config"):
                config.text_config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
...
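The `text_config` branch above exists because multi-modal configs such as `LlavaConfig` nest the language model's settings one level down, and the new `2048` default presumably covers configs whose top level does not expose `max_position_embeddings` at all. A small sketch of the shapes being handled:

# Sketch: what a LLaVA config looks like to the code above.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(hasattr(config, "text_config"))                    # True: the LM settings are nested
print(getattr(config, "max_position_embeddings", 2048))  # may fall back to the 2048 default at the top level
print(config.text_config.max_position_embeddings)        # the real context length lives here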
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
)
from transformers.models.llava.modeling_llava import (
    LlavaForConditionalGeneration as OldLlavaForConditionalGeneration,
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class LlavaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"
    @staticmethod
    def fuse_layers(model: OldLlavaForConditionalGeneration):
        fuser = LlavaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldLlavaForConditionalGeneration):
        return model.language_model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldLlamaDecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldLlavaForConditionalGeneration, device: str):
        model.language_model.model.embed_tokens = model.get_input_embeddings().to(device)
    @staticmethod
    def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))

        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers
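A note on the `v_proj`/`o_proj` shape check above (see the linked llm-awq discussion): with grouped-query attention the value projection has fewer output channels than `o_proj` has input channels, so the two cannot share one scale and that group is skipped. A small illustration with assumed LLaMA-7B-style dimensions, which is what the LLaVA-1.5 language tower uses:

# Illustration of the shape condition; the dimensions are assumptions for a
# 32-head, non-GQA LLaMA-style language model.
hidden_size, n_heads, n_kv_heads = 4096, 32, 32
head_dim = hidden_size // n_heads

v_proj_weight_shape = (n_kv_heads * head_dim, hidden_size)   # nn.Linear weight is (out_features, in_features)
o_proj_weight_shape = (hidden_size, n_heads * head_dim)

print(v_proj_weight_shape == o_proj_weight_shape)   # True here, so the o_proj group is included
# With GQA (e.g. n_kv_heads = 8) the shapes differ and the group would be skipped.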
class LlavaFuser:
    def __init__(self, model: OldLlavaForConditionalGeneration):
        self.model = model.language_model

        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]
    def fuse_transformer(self):
        blocks = []

        module: OldLlamaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
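The fuser is not invoked directly; it is reached through the `fuse_layers` flag of `from_quantized`, which calls `LlavaAWQForCausalLM.fuse_layers` and therefore `LlavaFuser.fuse_transformer` above. A usage sketch (checkpoint name reused from the inference example that follows):

# Usage sketch, not part of the diff: request layer fusion at load time.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "ybelkada/llava-1.5-7b-hf-awq",
    fuse_layers=True,     # routes through fuse_layers() -> LlavaFuser above
    safetensors=True,
)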
import requests
import torch
from PIL import Image
from awq import AutoAWQForCausalLM
from transformers import AutoProcessor
quant_path = "ybelkada/llava-1.5-7b-hf-awq"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, safetensors=True, device_map={"": 0})
processor = AutoProcessor.from_pretrained(quant_path)
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
# Generate output
generation_output = model.generate(
**inputs,
max_new_tokens=512
)
print(processor.decode(generation_output[0], skip_special_tokens=True))
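As an optional follow-up to the example above: assuming the wrapper's `generate` forwards extra keyword arguments to the underlying `transformers` `generate`, and that the LLaVA processor exposes its text tokenizer as `processor.tokenizer`, the answer can also be streamed token by token instead of decoded at the end. The second script below then shows how such a quantized checkpoint is produced in the first place.

# Hedged variant: stream the output with transformers' TextStreamer.
from transformers import TextStreamer

streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
model.generate(**inputs, max_new_tokens=512, streamer=streamer)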
import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = "llava-hf/llava-1.5-7b-hf"
quant_path = "llava-1.5-7b-hf-awq"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version":"GEMM"}
# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, safetensors=True, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')