Commit d5878167 authored by mashun1

llava-next
import os
AVAILABLE_MODELS = {
    "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
    "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
    "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
    "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig",
    # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
    # Add other models as needed
}

for model_name, model_classes in AVAILABLE_MODELS.items():
    try:
        exec(f"from .language_model.{model_name} import {model_classes}")
    except Exception as e:
        print(f"Failed to import {model_name} from llava.language_model.{model_name}. Error: {e}")
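# The loop above imports each model's classes dynamically so that optional backends
# (e.g. the commented-out MoE variant) can fail to import without breaking the package.
# A minimal sketch of the same pattern using importlib instead of exec, assuming the
# same llava.model package layout; the helper name is illustrative only:
import importlib

def _import_available_models(available_models, package="llava.model"):
    """Resolve each model's classes dynamically, skipping backends that fail to import."""
    loaded = {}
    for model_name, model_classes in available_models.items():
        try:
            module = importlib.import_module(f".language_model.{model_name}", package=package)
            for cls_name in (c.strip() for c in model_classes.split(",")):
                loaded[cls_name] = getattr(module, cls_name)
        except Exception as e:
            print(f"Failed to import {model_name}: {e}")
    return loaded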
"""
Usage:
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava import LlavaLlamaForCausalLM
def apply_delta(base_model_path, target_model_path, delta_path):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading delta")
    delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base.state_dict():
            assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data += base.state_dict()[name]
        else:
            assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
            bparam = base.state_dict()[name]
            param.data[: bparam.shape[0], : bparam.shape[1]] += bparam

    print("Saving target model")
    delta.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)

    args = parser.parse_args()

    apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
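# For intuition, apply_delta above is plain element-wise addition, except that the delta
# checkpoint may carry a larger (padded) vocabulary than the base model, in which case the
# base weights are only added into the top-left slice. A self-contained toy sketch of that
# reconstruction with made-up shapes, not real checkpoints:
import torch

base_embed = torch.randn(4, 8)    # base model: vocab 4, hidden 8
delta_embed = torch.randn(6, 8)   # delta checkpoint: vocab padded to 6 (e.g. extra special tokens)

reconstructed = delta_embed.clone()
reconstructed[: base_embed.shape[0], : base_embed.shape[1]] += base_embed  # same rule as the loop above
assert reconstructed.shape == (6, 8)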
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from llava.model import *
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.utils import rank0_print
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16",attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
kwargs["device_map"] = device_map
if load_8bit:
kwargs["load_in_8bit"] = True
elif load_4bit:
kwargs["load_in_4bit"] = True
kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
elif torch_dtype == "float16":
kwargs["torch_dtype"] = torch.float16
elif torch_dtype == "bfloat16":
kwargs["torch_dtype"] = torch.bfloat16
else:
raise ValueError(f"Unsupported torch_dtype: {torch_dtype}")
if customized_config is not None:
kwargs["config"] = customized_config
if "multimodal" in kwargs:
if kwargs["multimodal"] is True:
is_multimodal = True
kwargs.pop("multimodal")
else:
is_multimodal = False
if "llava" in model_name.lower() or is_multimodal:
# Load LLaVA model
if "lora" in model_name.lower() and model_base is None:
warnings.warn(
"There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
)
if "lora" in model_name.lower() and model_base is not None:
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
rank0_print("Loading LLaVA from base model...")
if "mixtral" in model_name.lower():
from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
elif "mistral" in model_name.lower():
from llava.model.language_model.llava_mistral import LlavaMistralConfig
lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
elif "gemma" in model_name.lower():
from llava.model.language_model.llava_gemma import LlavaGemmaConfig
lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
else:
from llava.model.language_model.llava_llama import LlavaConfig
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
if model.lm_head.weight.shape[0] != token_num:
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
rank0_print("Loading additional LLaVA weights...")
if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
else:
# this is probably from HF Hub
from huggingface_hub import hf_hub_download
def load_from_hf(repo_id, filename, subfolder=None):
cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
return torch.load(cache_file, map_location="cpu")
non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
if any(k.startswith("model.model.") for k in non_lora_trainables):
non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
model.load_state_dict(non_lora_trainables, strict=False)
from peft import PeftModel
rank0_print("Loading LoRA weights...")
model = PeftModel.from_pretrained(model, model_path)
rank0_print("Merging LoRA weights...")
model = model.merge_and_unload()
rank0_print("Model is loaded...")
elif model_base is not None:  # this may be an mm-projector-only checkpoint; load the projector onto a preset language model
rank0_print(f"Loading LLaVA from base model {model_base}...")
if "mixtral" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
elif "gemma" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
elif (
"wizardlm-2" in model_name.lower()
and "vicuna" in model_name.lower()
or "llama" in model_name.lower()
or "yi" in model_name.lower()
or "nous-hermes" in model_name.lower()
or "llava-v1.6-34b" in model_name.lower()
or "llava-v1.5" in model_name.lower()
):
from llava.model.language_model.llava_llama import LlavaConfig
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if customized_config is None:
llava_cfg = LlavaConfig.from_pretrained(model_path)
if "v1.5" in model_name.lower():
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
else:
llava_cfg = customized_config
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
llava_cfg = LlavaConfig.from_pretrained(model_path)
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
else:
raise ValueError(f"Model {model_name} not supported")
mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
model.load_state_dict(mm_projector_weights, strict=False)
else:
rank0_print(f"Loaded LLaVA model: {model_path}")
if "mixtral" in model_name.lower():
from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if customized_config is None:
llava_cfg = LlavaMixtralConfig.from_pretrained(model_path)
else:
llava_cfg = customized_config
if overwrite_config is not None:
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(llava_cfg, k, v)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
elif (
"wizardlm-2" in model_name.lower()
and "vicuna" in model_name.lower()
or "llama" in model_name.lower()
or "yi" in model_name.lower()
or "nous-hermes" in model_name.lower()
or "llava-v1.6-34b" in model_name.lower()
or "llava-v1.5" in model_name.lower()
):
from llava.model.language_model.llava_llama import LlavaConfig
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if customized_config is None:
llava_cfg = LlavaConfig.from_pretrained(model_path)
if "v1.5" in model_name.lower():
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
else:
llava_cfg = customized_config
if overwrite_config is not None:
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(llava_cfg, k, v)
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
elif "qwen" in model_name.lower() or "quyen" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path)
if "moe" in model_name.lower() or "A14B" in model_name.lower():
from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig
if overwrite_config is not None:
llava_cfg = LlavaQwenMoeConfig.from_pretrained(model_path)
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(llava_cfg, k, v)
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
else:
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
else:
from llava.model.language_model.llava_qwen import LlavaQwenConfig
if overwrite_config is not None:
llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(llava_cfg, k, v)
model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
else:
model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
elif "gemma" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
else:
try:
from llava.model.language_model.llava_llama import LlavaConfig
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if customized_config is None:
llava_cfg = LlavaConfig.from_pretrained(model_path)
if "v1.5" in model_path.lower():
llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
else:
llava_cfg = customized_config
if overwrite_config is not None:
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(llava_cfg, k, v)
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
except Exception as e:
raise ValueError(f"Model {model_name} not supported") from e
else:
# Load language model
if model_base is not None:
# PEFT model
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
print(f"Loading LoRA weights from {model_path}")
model = PeftModel.from_pretrained(model, model_path)
print(f"Merging weights")
model = model.merge_and_unload()
print("Convert to FP16...")
model.to(torch.float16)
else:
use_fast = False
if "mpt" in model_name.lower().replace("prompt", ""):
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
rank0_print(f"Model Class: {model.__class__.__name__}")
image_processor = None
if "llava" in model_name.lower() or is_multimodal:
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model(device_map=device_map)
if device_map != "auto":
vision_tower.to(device="cuda", dtype=torch.float16)
image_processor = vision_tower.image_processor
if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
elif hasattr(model.config, "max_position_embeddings"):
context_len = model.config.max_position_embeddings
elif hasattr(model.config, "tokenizer_model_max_length"):
context_len = model.config.tokenizer_model_max_length
else:
context_len = 2048
return tokenizer, model, image_processor, context_len
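# A typical call to load_pretrained_model returns the tokenizer, model, image processor and
# context length together. The sketch below is one plausible invocation; the checkpoint id is
# only an example and the keyword values mirror the defaults of the signature above.
if __name__ == "__main__":
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path="liuhaotian/llava-v1.5-7b",  # example checkpoint; any LLaVA-style directory or Hub repo id
        model_base=None,
        model_name="llava-v1.5-7b",
        torch_dtype="float16",
        device_map="auto",
    )
    rank0_print(model.__class__.__name__, context_len)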
"""
Usage:
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
"""
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model import *
from llava.model.utils import auto_upgrade
def consolidate_ckpt(src_path, dst_path):
    print("Loading model")
    auto_upgrade(src_path)
    src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
    src_model.save_pretrained(dst_path)
    src_tokenizer.save_pretrained(dst_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True)
    parser.add_argument("--dst", type=str, required=True)

    args = parser.parse_args()

    consolidate_ckpt(args.src, args.dst)
# Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaGemmaConfig(GemmaConfig):
model_type = "llava_gemma"
class LlavaGemmaModel(LlavaMetaModel, GemmaModel):
config_class = LlavaGemmaConfig
def __init__(self, config: GemmaConfig):
super(LlavaGemmaModel, self).__init__(config)
class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaGemmaConfig
def __init__(self, config):
super(GemmaForCausalLM, self).__init__(config)
self.model = LlavaGemmaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
AutoConfig.register("llava_gemma", LlavaGemmaConfig)
AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
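# The two register calls above bind the "llava_gemma" model_type to LlavaGemmaConfig and
# LlavaGemmaForCausalLM, so the transformers Auto classes can resolve such a checkpoint
# without the caller importing this module's classes directly. A hedged sketch of what that
# enables (the checkpoint path is hypothetical):
#
#     from transformers import AutoConfig, AutoModelForCausalLM
#
#     cfg = AutoConfig.from_pretrained("/path/to/llava-gemma-checkpoint")  # model_type == "llava_gemma"
#     model = AutoModelForCausalLM.from_pretrained("/path/to/llava-gemma-checkpoint", config=cfg)
#     assert isinstance(model, LlavaGemmaForCausalLM)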
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
from torch.nn import CrossEntropyLoss
# , LlamaModel, LlamaForCausalLM, GenerationConfig
# from .modeling_llama import LlamaModel, LlamaForCausalLM
from transformers import LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaConfig(LlamaConfig):
model_type = "llava_llama"
temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
max_new_tokens: int = 1024
do_sample: bool = False
top_p: Optional[float] = None
# rope_scaling: Optional[dict] = {}
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
config_class = LlavaConfig
def __init__(self, config: LlamaConfig):
super(LlavaLlamaModel, self).__init__(config)
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaConfig
def __init__(self, config):
LlamaForCausalLM.__init__(self, config)
# configure default generation settings
config.model_type = "llava_llama"
# config.rope_scaling = None
self.model = LlavaLlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
modalities: Optional[List[str]] = ["image"],
dpo_forward: Optional[bool] = None,
cache_position=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
if dpo_forward:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
return logits, labels
else:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
modalities: Optional[List[str]] = ["image"],
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
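# Rough usage sketch for the generate() override above: the prompt must contain the image
# placeholder token so prepare_inputs_labels_for_multimodal can splice the visual features
# into the embedding sequence. This assumes tokenizer/model/image_processor were obtained
# via load_pretrained_model earlier in this commit, uses helpers from llava.constants and
# llava.mm_utils, and omits the conversation-template formatting of the real inference scripts.
#
#     import torch
#     from PIL import Image
#     from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
#     from llava.mm_utils import tokenizer_image_token
#
#     image = Image.open("example.jpg")  # placeholder path
#     image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
#     image_tensor = image_tensor.to(model.device, dtype=torch.float16)
#
#     prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe the image."
#     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)
#
#     output_ids = model.generate(inputs=input_ids, images=image_tensor, image_sizes=[image.size], max_new_tokens=128)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))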
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaMistralConfig(MistralConfig):
model_type = "llava_mistral"
temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
max_new_tokens: int = 1024
do_sample: bool = False
top_p: Optional[float] = None
class LlavaMistralModel(LlavaMetaModel, MistralModel):
config_class = LlavaMistralConfig
def __init__(self, config: MistralConfig):
super(LlavaMistralModel, self).__init__(config)
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaMistralConfig
def __init__(self, config):
super(MistralForCausalLM, self).__init__(config)
config.model_type = "llava_mistral"
config.rope_scaling = None
self.model = LlavaMistralModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
AutoConfig.register("llava_mistral", LlavaMistralConfig)
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaMixtralConfig(MixtralConfig):
model_type = "llava_mixtral"
class LlavaMixtralModel(LlavaMetaModel, MixtralModel):
config_class = LlavaMixtralConfig
def __init__(self, config: MixtralConfig):
super(LlavaMixtralModel, self).__init__(config)
class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaMixtralConfig
def __init__(self, config):
super(MixtralForCausalLM, self).__init__(config)
config.model_type = "llava_mixtral"
config.rope_scaling = None
self.model = LlavaMixtralModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
modalities: Optional[List[str]] = ["image"],
dpo_forward: Optional[bool] = None,
cache_position=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
if dpo_forward:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
return logits, labels
else:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
modalities: Optional[List[str]] = ["image"],
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaMptConfig(MptConfig):
model_type = "llava_mpt"
class LlavaMptModel(LlavaMetaModel, MptModel):
config_class = LlavaMptConfig
def __init__(self, config: MptConfig):
config.hidden_size = config.d_model
super(LlavaMptModel, self).__init__(config)
def embed_tokens(self, x):
return self.wte(x)
class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaMptConfig
supports_gradient_checkpointing = True
def __init__(self, config):
super(MptForCausalLM, self).__init__(config)
config.model_type = "llava_mpt"
config.rope_scaling = None
self.generation_config = GenerationConfig(
temperature=0.0,
max_new_tokens=1024,
do_sample=False,
top_p=None,
)
self.transformer = LlavaMptModel(config)
self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.transformer
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlavaMptModel):
module.gradient_checkpointing = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position=None,
images=None,
):
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
return super().forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
_inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
_inputs["images"] = images
return _inputs
AutoConfig.register("llava_mpt", LlavaMptConfig)
AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
(Four file diffs are collapsed and omitted here.)
"""
Usage:
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model.utils import auto_upgrade
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Calculating delta")
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base.state_dict():
            assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data -= base.state_dict()[name]
        else:
            assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
            bparam = base.state_dict()[name]
            param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    parser.add_argument("--hub-repo-id", type=str, default=None)

    args = parser.parse_args()

    make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
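# make_delta is the inverse of apply_delta earlier in this commit: subtract the base weights
# to publish a delta checkpoint, then add them back to reconstruct the target. A toy
# round-trip check on a single tensor (shapes are made up):
#
#     import torch
#
#     base = torch.randn(4, 8)
#     target = torch.randn(4, 8)
#     delta = target - base                          # what make_delta stores
#     assert torch.allclose(delta + base, target)    # what apply_delta reconstructs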
import os
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
from .imagebind import ImageBindWrapper
from .open_clip_encoder import OpenCLIPVisionTower
from .hf_vision import HFVisionTower
from .siglip_encoder import SigLipVisionTower
from .mlcd_encoder import MLCDVisionTower, MLCDVisionTowerS2
# from .eva_clip.eva_clip_encoder import EvaClipVisionTower
# from .dev_eva_clip.eva_vit import EvaViTWrapper
def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    is_absolute_path_exists = os.path.exists(vision_tower)
    use_s2 = getattr(vision_tower_cfg, "s2", False)

    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        if use_s2:
            return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
        else:
            return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif "siglip" in vision_tower:
        return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
    elif vision_tower.startswith("hf:"):
        return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif vision_tower in ["imagebind_huge"]:
        return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
    elif vision_tower.startswith("open_clip_hub"):
        return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif "mlcd-vit-bigG-patch14" in vision_tower:
        if use_s2:
            return MLCDVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
        else:
            return MLCDVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower():
    #     return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]:
    #     return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f"Unknown vision tower: {vision_tower}")
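# build_vision_tower itself only inspects mm_vision_tower/vision_tower and the optional s2
# flag, but the tower classes read further multimodal attributes from the same config (for
# example the feature-selection layer). A hedged sketch with a stand-in config object; the
# extra attribute names are assumptions based on how the CLIP tower is typically configured:
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(
#         mm_vision_tower="openai/clip-vit-large-patch14-336",
#         s2=False,
#         mm_vision_select_layer=-2,
#         mm_vision_select_feature="patch",
#     )
#     tower = build_vision_tower(cfg, delay_load=True)  # defer weight loading until load_model()
#     tower.load_model()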
(One file diff is collapsed and omitted here.)
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
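# These are the per-channel mean and standard deviation of OpenAI's CLIP training data,
# used to normalize images before they enter the vision tower. An illustrative composition
# with torchvision (assumed dependency; not necessarily identical to image_transform in
# transform.py):
#
#     from torchvision import transforms
#
#     preprocess = transforms.Compose([
#         transforms.Resize(224),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
#     ])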