Unverified Commit dafd5951 authored by isaac-vidas's avatar isaac-vidas Committed by GitHub
Browse files

[`Llava`] Update convert_llava_weights_to_hf.py script (#28617)

* Update convert_llava_weights_to_hf.py script

* Remove config update of adding padding to `vocab_size` and `text_config.vocab_size` which causes `ValueError` exception.
* Remove keys that ends with `inv_freq` from the state dict.
* Add examples and instructions for creating `model_state_dict.bin` that can be used by the script.

* Update convert_llava_weights_to_hf.py

* Update convert_vipllava_weights_to_hf.py
parent deb2b590
......@@ -27,6 +27,25 @@ from transformers import (
)
EPILOG_TXT = """Example:
python transformers/src/transformers/models/llava/convert_llava_weights_to_hf.py --text_model_id lmsys/vicuna-7b-v1.5 --vision_model_id openai/clip-vit-large-patch14-336 --output_hub_path org/llava-v1.5-7b-conv --old_state_dict_id liuhaotian/llava-v1.5-7b
Example for creating the old state dict file with Python:
import torch
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
# load model
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
model = LlavaLlamaForCausalLM.from_pretrained("liuhaotian/llava-v1.5-7b", low_cpu_mem_usage=True, **kwargs)
# load vision tower
model.get_vision_tower().load_model()
# Save state dict
torch.save(model.state_dict(), "tmp/hf_models/llava-v1.5-7b/model_state_dict.bin")
"""
KEYS_TO_MODIFY_MAPPING = {
"model.vision_tower.": "",
"model.mm_projector": "multi_modal_projector",
......@@ -42,6 +61,8 @@ KEYS_TO_MODIFY_MAPPING = {
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
......@@ -93,15 +114,16 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
dim=0,
)
model.config.vocab_size = model.config.vocab_size + pad_shape
model.config.text_config.vocab_size = model.config.text_config.vocab_size + pad_shape
model.push_to_hub(output_hub_path)
processor.push_to_hub(output_hub_path)
def main():
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--text_model_id",
help="Hub location of the text model",
......
......@@ -46,6 +46,8 @@ KEYS_TO_MODIFY_MAPPING = {
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
......@@ -97,8 +99,6 @@ def convert_vipllava_llama_to_hf(text_model_id, vision_model_id, output_hub_path
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
dim=0,
)
model.config.vocab_size = model.config.vocab_size + pad_shape
model.config.text_config.vocab_size = model.config.text_config.vocab_size + pad_shape
model.push_to_hub(output_hub_path)
processor.push_to_hub(output_hub_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment