Commit d5878167 authored by mashun1

llava-next
import argparse
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
def load_image(image_file):
if image_file.startswith("http") or image_file.startswith("https"):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
return image
def main(args):
# Model
disable_torch_init()
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
if "llama-2" in model_name.lower():
conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
else:
conv_mode = "llava_v0"
if args.conv_mode is not None and conv_mode != args.conv_mode:
print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode))
else:
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
if "mpt" in model_name.lower():
roles = ("user", "assistant")
else:
roles = conv.roles
image = load_image(args.image_file)
image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().cuda()
while True:
try:
inp = input(f"{roles[0]}: ")
except EOFError:
inp = ""
if not inp:
print("exit...")
break
print(f"{roles[1]}: ", end="")
if image is not None:
# first message
if model.config.mm_use_im_start_end:
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
else:
inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
conv.append_message(conv.roles[0], inp)
image = None
else:
# later messages
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.inference_mode():
output_ids = model.generate(input_ids, images=image_tensor, do_sample=True if args.temperature > 0 else False, temperature=args.temperature, max_new_tokens=args.max_new_tokens, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria])
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
conv.messages[-1][-1] = outputs
if args.debug:
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
main(args)
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import time
from typing import List, Union
import threading
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn
from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from llava.utils import build_logger, server_error_msg
logger = build_logger("controller", "controller.log")
class DispatchMethod(Enum):
LOTTERY = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, name):
if name == "lottery":
return cls.LOTTERY
elif name == "shortest_queue":
return cls.SHORTEST_QUEUE
else:
raise ValueError(f"Invalid dispatch method")
@dataclasses.dataclass
class WorkerInfo:
model_names: List[str]
speed: int
queue_length: int
check_heart_beat: bool
last_heart_beat: float
def heart_beat_controller(controller):
while True:
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
controller.remove_stale_workers_by_expiration()
class Controller:
def __init__(self, dispatch_method: str):
# Dict[str -> WorkerInfo]
self.worker_info = {}
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self,))
self.heart_beat_thread.start()
logger.info("Init controller")
def register_worker(self, worker_name: str, check_heart_beat: bool, worker_status: dict):
if worker_name not in self.worker_info:
logger.info(f"Register a new worker: {worker_name}")
else:
logger.info(f"Register an existing worker: {worker_name}")
if not worker_status:
worker_status = self.get_worker_status(worker_name)
if not worker_status:
return False
self.worker_info[worker_name] = WorkerInfo(worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], check_heart_beat, time.time())
logger.info(f"Register done: {worker_name}, {worker_status}")
return True
def get_worker_status(self, worker_name: str):
try:
r = requests.post(worker_name + "/worker_get_status", timeout=5)
except requests.exceptions.RequestException as e:
logger.error(f"Get status fails: {worker_name}, {e}")
return None
if r.status_code != 200:
logger.error(f"Get status fails: {worker_name}, {r}")
return None
return r.json()
def remove_worker(self, worker_name: str):
del self.worker_info[worker_name]
def refresh_all_workers(self):
old_info = dict(self.worker_info)
self.worker_info = {}
for w_name, w_info in old_info.items():
if not self.register_worker(w_name, w_info.check_heart_beat, None):
logger.info(f"Remove stale worker: {w_name}")
def list_models(self):
model_names = set()
for w_name, w_info in self.worker_info.items():
model_names.update(w_info.model_names)
return list(model_names)
def get_worker_address(self, model_name: str):
if self.dispatch_method == DispatchMethod.LOTTERY:
worker_names = []
worker_speeds = []
for w_name, w_info in self.worker_info.items():
if model_name in w_info.model_names:
worker_names.append(w_name)
worker_speeds.append(w_info.speed)
worker_speeds = np.array(worker_speeds, dtype=np.float32)
norm = np.sum(worker_speeds)
if norm < 1e-4:
return ""
worker_speeds = worker_speeds / norm
if True: # Directly return address
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
worker_name = worker_names[pt]
return worker_name
# Check status before returning
while True:
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
worker_name = worker_names[pt]
if self.get_worker_status(worker_name):
break
else:
self.remove_worker(worker_name)
worker_speeds[pt] = 0
norm = np.sum(worker_speeds)
if norm < 1e-4:
return ""
worker_speeds = worker_speeds / norm
continue
return worker_name
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
worker_names = []
worker_qlen = []
for w_name, w_info in self.worker_info.items():
if model_name in w_info.model_names:
worker_names.append(w_name)
worker_qlen.append(w_info.queue_length / w_info.speed)
if len(worker_names) == 0:
return ""
min_index = np.argmin(worker_qlen)
w_name = worker_names[min_index]
self.worker_info[w_name].queue_length += 1
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
return w_name
else:
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
def receive_heart_beat(self, worker_name: str, queue_length: int):
if worker_name not in self.worker_info:
logger.info(f"Receive unknown heart beat. {worker_name}")
return False
self.worker_info[worker_name].queue_length = queue_length
self.worker_info[worker_name].last_heart_beat = time.time()
logger.info(f"Receive heart beat. {worker_name}")
return True
def remove_stale_workers_by_expiration(self):
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
to_delete = []
for worker_name, w_info in self.worker_info.items():
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
to_delete.append(worker_name)
for worker_name in to_delete:
self.remove_worker(worker_name)
def worker_api_generate_stream(self, params):
worker_addr = self.get_worker_address(params["model"])
if not worker_addr:
logger.info(f"no worker: {params['model']}")
ret = {
"text": server_error_msg,
"error_code": 2,
}
yield json.dumps(ret).encode() + b"\0"
try:
response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk + b"\0"
except requests.exceptions.RequestException as e:
logger.info(f"worker timeout: {worker_addr}")
ret = {
"text": server_error_msg,
"error_code": 3,
}
yield json.dumps(ret).encode() + b"\0"
# Let the controller act as a worker to achieve hierarchical
# management. This can be used to connect isolated sub networks.
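# Hedged sketch: a sub-controller could be registered with a top-level
# controller exactly like an ordinary worker (the hostnames are assumptions):
#
#   requests.post("http://top-controller:21001/register_worker",
#                 json={"worker_name": "http://sub-controller:21001",
#                       "check_heart_beat": False,
#                       "worker_status": None})
#
# The top-level controller then aggregates model names, speed, and queue length
# through the sub-controller's /worker_get_status route implemented below.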
def worker_api_get_status(self):
model_names = set()
speed = 0
queue_length = 0
for w_name in self.worker_info:
worker_status = self.get_worker_status(w_name)
if worker_status is not None:
model_names.update(worker_status["model_names"])
speed += worker_status["speed"]
queue_length += worker_status["queue_length"]
return {
"model_names": list(model_names),
"speed": speed,
"queue_length": queue_length,
}
app = FastAPI()
@app.post("/register_worker")
async def register_worker(request: Request):
data = await request.json()
controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None))
@app.post("/refresh_all_workers")
async def refresh_all_workers():
models = controller.refresh_all_workers()
@app.post("/list_models")
async def list_models():
models = controller.list_models()
return {"models": models}
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
data = await request.json()
addr = controller.get_worker_address(data["model"])
return {"address": addr}
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
data = await request.json()
exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
return {"exist": exist}
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
params = await request.json()
generator = controller.worker_api_generate_stream(params)
return StreamingResponse(generator)
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
return controller.worker_api_get_status()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21001)
parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue")
args = parser.parse_args()
logger.info(f"args: {args}")
controller = Controller(args.dispatch_method)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
import argparse
import datetime
import json
import os
import time
import gradio as gr
import requests
from llava.conversation import default_conversation, conv_templates, SeparatorStyle
from llava.constants import LOGDIR
from llava.utils import build_logger, server_error_msg, violates_moderation, moderation_msg
import hashlib
logger = build_logger("gradio_web_server", "gradio_web_server.log")
headers = {"User-Agent": "LLaVA Client"}
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
priority = {
"vicuna-13b": "aaaaaaa",
"koala-13b": "aaaaaab",
}
def get_conv_log_filename():
t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
return name
def get_model_list():
ret = requests.post(args.controller_url + "/refresh_all_workers")
assert ret.status_code == 200
ret = requests.post(args.controller_url + "/list_models")
models = ret.json()["models"]
models.sort(key=lambda x: priority.get(x, x))
logger.info(f"Models: {models}")
return models
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
dropdown_update = gr.Dropdown.update(visible=True)
if "model" in url_params:
model = url_params["model"]
if model in models:
dropdown_update = gr.Dropdown.update(value=model, visible=True)
state = default_conversation.copy()
return (state, dropdown_update, gr.Chatbot.update(visible=True), gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Row.update(visible=True), gr.Accordion.update(visible=True))
def load_demo_refresh_model_list(request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}")
models = get_model_list()
state = default_conversation.copy()
return (
state,
gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else ""),
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Button.update(visible=True),
gr.Row.update(visible=True),
gr.Accordion.update(visible=True),
)
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(time.time(), 4),
"type": vote_type,
"model": model_selector,
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
def upvote_last_response(state, model_selector, request: gr.Request):
logger.info(f"upvote. ip: {request.client.host}")
vote_last_response(state, "upvote", model_selector, request)
return ("",) + (disable_btn,) * 3
def downvote_last_response(state, model_selector, request: gr.Request):
logger.info(f"downvote. ip: {request.client.host}")
vote_last_response(state, "downvote", model_selector, request)
return ("",) + (disable_btn,) * 3
def flag_last_response(state, model_selector, request: gr.Request):
logger.info(f"flag. ip: {request.client.host}")
vote_last_response(state, "flag", model_selector, request)
return ("",) + (disable_btn,) * 3
def regenerate(state, image_process_mode, request: gr.Request):
logger.info(f"regenerate. ip: {request.client.host}")
state.messages[-1][-1] = None
prev_human_msg = state.messages[-2]
if type(prev_human_msg[1]) in (tuple, list):
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = default_conversation.copy()
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
def add_text(state, text, image, image2, image_process_mode, request: gr.Request):
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
if len(text) <= 0 and image is None:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
if args.moderate:
flagged = violates_moderation(text)
if flagged:
state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg, None, None) + (no_change_btn,) * 5
text = text[:3072] # Hard cut-off
images = [x for x in [image, image2] if x is not None]
num_images = len(images)
if num_images > 0:
text = text.replace("<image>", "").strip()
text = text[: 3072 - 512 * num_images]
text = "<image>\n" * num_images + text
text = (text, images, image_process_mode)
if len(state.get_images(return_pil=True)) > 0:
state = default_conversation.copy()
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
logger.info(f"http_bot. ip: {request.client.host}")
start_tstamp = time.time()
model_name = model_selector
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.messages) == state.offset + 2:
# First round of conversation
if "llava" in model_name.lower():
if "llama-2" in model_name.lower():
if "sharegpt" in model_name.lower():
if "mmtag" in model_name.lower():
template_name = "v1_mmtag"
elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
template_name = "v1_mmtag"
else:
template_name = "llava_v1"
else:
if "mmtag" in model_name.lower():
template_name = "llava_llama_2_mmtag"
elif "simple" in model_name.lower():
template_name = "llava_llama_2_simple"
elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
template_name = "llava_llama_2_mmtag"
elif "simple" in model_name.lower():
template_name = "llava_llama_2_simple"
else:
template_name = "llava_llama_2"
elif "v1" in model_name.lower():
if "mmtag" in model_name.lower():
template_name = "v1_mmtag"
elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
template_name = "v1_mmtag"
else:
template_name = "llava_v1"
elif "mpt" in model_name.lower():
template_name = "mpt"
else:
if "mmtag" in model_name.lower():
template_name = "v0_mmtag"
elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
template_name = "v0_mmtag"
else:
template_name = "llava_v0"
elif "mpt" in model_name.lower():
template_name = "mpt_text"
elif "llama-2" in model_name.lower():
if "sharegpt" in model_name.lower():
template_name = "vicuna_v1"
else:
template_name = "llama_2"
else:
template_name = "vicuna_v1"
new_state = conv_templates[template_name].copy()
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
# Query worker address
controller_url = args.controller_url
ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name})
worker_addr = ret.json()["address"]
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
# No available worker
if worker_addr == "":
state.messages[-1][-1] = server_error_msg
yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
# Construct prompt
prompt = state.get_prompt()
all_images = state.get_images(return_pil=True)
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
for image, hash in zip(all_images, all_image_hash):
t = datetime.datetime.now()
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
if not os.path.isfile(filename):
os.makedirs(os.path.dirname(filename), exist_ok=True)
image.save(filename)
# Make requests
pload = {
"model": model_name,
"prompt": prompt,
"temperature": float(temperature),
"top_p": float(top_p),
"max_new_tokens": min(int(max_new_tokens), 1536),
"stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
"images": f"List of {len(state.get_images())} images: {all_image_hash}",
}
logger.info(f"==== request ====\n{pload}")
pload["images"] = state.get_images()
state.messages[-1][-1] = "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
# Stream output
response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=10)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][len(prompt) :].strip()
state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
time.sleep(0.03)
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"images": all_image_hash,
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
title_markdown = """
# 🌋 LLaVA: Large Language and Vision Assistant
[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
"""
tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
"""
learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""
block_css = """
#buttons button {
min-width: min(120px,100%);
}
#chatbot img {
display: inline-block;
}
"""
def build_demo(embed_mode):
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
state = gr.State()
if not embed_mode:
gr.Markdown(title_markdown)
with gr.Row():
with gr.Column(scale=3):
with gr.Row(elem_id="model_selector_row"):
model_selector = gr.Dropdown(choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False)
with gr.Row(elem_id="images"):
imagebox = gr.Image(type="pil")
imagebox_2 = gr.Image(type="pil")
image_process_mode = gr.Radio(["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False)
cur_dir = os.path.dirname(os.path.abspath(__file__))
gr.Examples(
examples=[
[f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
[f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
],
inputs=[imagebox, textbox],
)
with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.2,
step=0.1,
interactive=True,
label="Temperature",
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.1,
interactive=True,
label="Top P",
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
step=64,
interactive=True,
label="Max output tokens",
)
with gr.Column(scale=8):
chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", visible=False, height=550)
with gr.Row():
with gr.Column(scale=8):
textbox.render()
with gr.Column(scale=1, min_width=50):
submit_btn = gr.Button(value="Submit", visible=False)
with gr.Row(visible=False) as button_row:
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
if not embed_mode:
gr.Markdown(tos_markdown)
gr.Markdown(learn_more_markdown)
url_params = gr.JSON(visible=False)
# Register listeners
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
upvote_btn.click(upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
downvote_btn.click(downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
flag_btn.click(flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
regenerate_btn.click(regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list)
clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, imagebox_2] + btn_list)
textbox.submit(add_text, [state, textbox, imagebox, imagebox_2, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then(
http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list
)
submit_btn.click(add_text, [state, textbox, imagebox, imagebox_2, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then(
http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list
)
if args.model_list_mode == "once":
demo.load(load_demo, [url_params], [state, model_selector, chatbot, textbox, submit_btn, button_row, parameter_row], _js=get_window_url_params)
elif args.model_list_mode == "reload":
demo.load(load_demo_refresh_model_list, None, [state, model_selector, chatbot, textbox, submit_btn, button_row, parameter_row])
else:
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
parser.add_argument("--concurrency-count", type=int, default=8)
parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument("--share", action="store_true")
parser.add_argument("--moderate", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
models = get_model_list()
logger.info(args)
demo = build_demo(args.embed)
demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False).launch(server_name=args.host, server_port=args.port, share=args.share)
import argparse
import datetime
import json
import os
import time
import gradio as gr
import requests
from llava.conversation import default_conversation, conv_templates, SeparatorStyle
from llava.constants import LOGDIR
from llava.utils import build_logger, server_error_msg, violates_moderation, moderation_msg
import hashlib
logger = build_logger("gradio_web_server", "gradio_web_server.log")
headers = {"User-Agent": "LLaVA Client"}
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
priority = {
"vicuna-13b": "aaaaaaa",
"koala-13b": "aaaaaab",
}
def get_conv_log_filename():
t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
return name
def get_model_list():
ret = requests.post(args.controller_url + "/refresh_all_workers")
assert ret.status_code == 200
ret = requests.post(args.controller_url + "/list_models")
models = ret.json()["models"]
models.sort(key=lambda x: priority.get(x, x))
logger.info(f"Models: {models}")
return models
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
dropdown_update = gr.Dropdown.update(visible=True)
if "model" in url_params:
model = url_params["model"]
if model in models:
dropdown_update = gr.Dropdown.update(value=model, visible=True)
state = default_conversation.copy()
return state, dropdown_update
def load_demo_refresh_model_list(request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}")
models = get_model_list()
state = default_conversation.copy()
dropdown_update = gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else "")
return state, dropdown_update
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(time.time(), 4),
"type": vote_type,
"model": model_selector,
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
def upvote_last_response(state, model_selector, request: gr.Request):
logger.info(f"upvote. ip: {request.client.host}")
vote_last_response(state, "upvote", model_selector, request)
return ("",) + (disable_btn,) * 3
def downvote_last_response(state, model_selector, request: gr.Request):
logger.info(f"downvote. ip: {request.client.host}")
vote_last_response(state, "downvote", model_selector, request)
return ("",) + (disable_btn,) * 3
def flag_last_response(state, model_selector, request: gr.Request):
logger.info(f"flag. ip: {request.client.host}")
vote_last_response(state, "flag", model_selector, request)
return ("",) + (disable_btn,) * 3
def regenerate(state, image_process_mode, request: gr.Request):
logger.info(f"regenerate. ip: {request.client.host}")
state.messages[-1][-1] = None
prev_human_msg = state.messages[-2]
if type(prev_human_msg[1]) in (tuple, list):
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = default_conversation.copy()
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def add_text(state, text, image, image_process_mode, request: gr.Request):
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
if len(text) <= 0 and image is None:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
if args.moderate:
flagged = violates_moderation(text)
if flagged:
state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5
text = text[:1536] # Hard cut-off
if image is not None:
text = text[:1200] # Hard cut-off for images
if "<image>" not in text:
# text = '<Image><image></Image>' + text
text = text + "\n<image>"
text = (text, image, image_process_mode)
if len(state.get_images(return_pil=True)) > 0:
state = default_conversation.copy()
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request, template_name=None):
logger.info(f"http_bot. ip: {request.client.host}")
start_tstamp = time.time()
model_name = model_selector
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.messages) == state.offset + 2:
# First round of conversation
if "llava" in model_name.lower():
if "llama-2" in model_name.lower():
template_name = "llava_llama_2"
elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
if "orca" in model_name.lower():
template_name = "mistral_orca"
elif "hermes" in model_name.lower():
template_name = "mistral_direct"
else:
template_name = "mistral_instruct"
elif "zephyr" in model_name.lower():
template_name = "mistral_zephyr"
elif "hermes" in model_name.lower():
template_name = "mistral_direct"
elif "v1" in model_name.lower():
if "mmtag" in model_name.lower():
template_name = "llava_v1_mmtag"
elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
template_name = "llava_v1_mmtag"
else:
template_name = "llava_v1"
elif "mpt" in model_name.lower():
template_name = "mpt"
else:
if "mmtag" in model_name.lower():
template_name = "v0_plain"
elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
template_name = "v0_plain"
else:
template_name = "llava_v0"
elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
if "orca" in model_name.lower():
template_name = "mistral_orca"
elif "hermes" in model_name.lower():
template_name = "mistral_direct"
else:
template_name = "mistral_instruct"
elif "hermes" in model_name.lower():
template_name = "mistral_direct"
elif "zephyr" in model_name.lower():
template_name = "mistral_zephyr"
elif "mpt" in model_name:
template_name = "mpt_text"
elif "llama-2" in model_name:
template_name = "llama_2"
else:
template_name = "vicuna_v1"
new_state = conv_templates[template_name].copy()
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
# Query worker address
controller_url = args.controller_url
ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name})
worker_addr = ret.json()["address"]
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
# No available worker
if worker_addr == "":
state.messages[-1][-1] = server_error_msg
yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
# Construct prompt
prompt = state.get_prompt()
all_images = state.get_images(return_pil=True)
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
for image, hash in zip(all_images, all_image_hash):
t = datetime.datetime.now()
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
if not os.path.isfile(filename):
os.makedirs(os.path.dirname(filename), exist_ok=True)
image.save(filename)
# Make requests
pload = {
"model": model_name,
"prompt": prompt,
"temperature": float(temperature),
"top_p": float(top_p),
"max_new_tokens": min(int(max_new_tokens), 1536),
"stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
"images": f"List of {len(state.get_images())} images: {all_image_hash}",
}
logger.info(f"==== request ====\n{pload}")
pload["images"] = state.get_images()
state.messages[-1][-1] = "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
# Stream output
response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=100)
last_print_time = time.time()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][len(prompt) :].strip()
state.messages[-1][-1] = output + "▌"
if time.time() - last_print_time > 0.05:
last_print_time = time.time()
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
time.sleep(0.03)
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"images": all_image_hash,
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
title_markdown = """
# 🌋 LLaVA: Large Language and Vision Assistant
[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
"""
tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
"""
learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def build_demo(embed_mode):
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
state = gr.State()
if not embed_mode:
gr.Markdown(title_markdown)
with gr.Row():
with gr.Column(scale=3):
with gr.Row(elem_id="model_selector_row"):
model_selector = gr.Dropdown(choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False)
imagebox = gr.Image(type="pil")
image_process_mode = gr.Radio(["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False)
cur_dir = os.path.dirname(os.path.abspath(__file__))
gr.Examples(
examples=[
[f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
[f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
],
inputs=[imagebox, textbox],
)
with gr.Accordion("Parameters", open=False) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.2,
step=0.1,
interactive=True,
label="Temperature",
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.1,
interactive=True,
label="Top P",
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
step=64,
interactive=True,
label="Max output tokens",
)
with gr.Column(scale=8):
chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
with gr.Row():
with gr.Column(scale=8):
textbox.render()
with gr.Column(scale=1, min_width=50):
submit_btn = gr.Button(value="Send", variant="primary")
with gr.Row(elem_id="buttons") as button_row:
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
if not embed_mode:
gr.Markdown(tos_markdown)
gr.Markdown(learn_more_markdown)
url_params = gr.JSON(visible=False)
# Register listeners
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
upvote_btn.click(upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False)
downvote_btn.click(downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False)
flag_btn.click(flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False)
regenerate_btn.click(regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list)
clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False)
textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then(
http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list
)
submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then(
http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list
)
if args.model_list_mode == "once":
demo.load(load_demo, [url_params], [state, model_selector], _js=get_window_url_params, queue=False)
elif args.model_list_mode == "reload":
demo.load(load_demo_refresh_model_list, None, [state, model_selector], queue=False)
else:
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument("--share", action="store_true")
parser.add_argument("--moderate", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
models = get_model_list()
logger.info(args)
demo = build_demo(args.embed)
demo.queue(concurrency_count=args.concurrency_count, api_open=False).launch(server_name=args.host, server_port=args.port, share=args.share)
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial
from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.utils import build_logger, server_error_msg, pretty_print_semaphore
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread
GB = 1 << 30
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0
model_semaphore = None
def heart_beat_worker(controller):
while True:
time.sleep(WORKER_HEART_BEAT_INTERVAL)
controller.send_heart_beat()
class ModelWorker:
def __init__(self, controller_addr, worker_addr, worker_id, no_register, model_path, model_base, model_name, load_8bit, load_4bit):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
if model_path.endswith("/"):
model_path = model_path[:-1]
if model_name is None:
model_paths = model_path.split("/")
if model_paths[-1].startswith("checkpoint-"):
self.model_name = model_paths[-2] + "_" + model_paths[-1]
else:
self.model_name = model_paths[-1]
else:
self.model_name = model_name
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, self.model_name, load_8bit, load_4bit)
self.is_multimodal = "llava" in self.model_name.lower()
if not no_register:
self.register_to_controller()
self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,))
self.heart_beat_thread.start()
def register_to_controller(self):
logger.info("Register to controller")
url = self.controller_addr + "/register_worker"
data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
r = requests.post(url, json=data)
assert r.status_code == 200
def send_heart_beat(self):
logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " f"global_counter: {global_counter}")
url = self.controller_addr + "/receive_heart_beat"
while True:
try:
ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5)
exist = ret.json()["exist"]
break
except requests.exceptions.RequestException as e:
logger.error(f"heart beat error: {e}")
time.sleep(5)
if not exist:
self.register_to_controller()
def get_queue_length(self):
if model_semaphore is None:
return 0
else:
return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
def get_status(self):
return {
"model_names": [self.model_name],
"speed": 1,
"queue_length": self.get_queue_length(),
}
@torch.inference_mode()
def generate_stream(self, params):
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
prompt = params["prompt"]
ori_prompt = prompt
images = params.get("images", None)
num_image_tokens = 0
if images is not None and len(images) > 0 and self.is_multimodal:
if len(images) > 0:
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
raise ValueError("Number of images does not match number of <image> tokens in prompt")
images = [load_image_from_base64(image) for image in images]
image_sizes = [image.size for image in images]
images = process_images(images, image_processor, model.config)
if type(images) is list:
images = [image.to(self.model.device, dtype=torch.float16) for image in images]
else:
images = images.to(self.model.device, dtype=torch.float16)
replace_token = DEFAULT_IMAGE_TOKEN
if getattr(self.model.config, "mm_use_im_start_end", False):
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
else:
images = None
image_sizes = None
image_args = {"images": images, "image_sizes": image_sizes}
else:
images = None
image_args = {}
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
max_context_length = getattr(model.config, "max_position_embeddings", 2048)
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
stop_str = params.get("stop", None)
do_sample = True if temperature > 0.001 else False
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
if max_new_tokens < 1:
yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
return
thread = Thread(
target=model.generate,
kwargs=dict(
inputs=input_ids,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
max_new_tokens=max_new_tokens,
streamer=streamer,
# stopping_criteria=[stopping_criteria],
use_cache=True,
**image_args,
),
)
thread.start()
start_time = time.time()
generated_text = ori_prompt
for new_text in streamer:
generated_text += new_text
if generated_text.endswith(stop_str):
generated_text = generated_text[: -len(stop_str)]
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
end_time = time.time()
new_generated = generated_text[len(ori_prompt) :]
new_generated_tokens = tokenizer(new_generated).input_ids
token_per_second = len(new_generated_tokens) / (end_time - start_time)
print(f"token_per_second: {token_per_second}")
def generate_stream_gate(self, params):
try:
for x in self.generate_stream(params):
yield x
except ValueError as e:
print("Caught ValueError:", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.CudaError as e:
print("Caught torch.cuda.CudaError:", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
except Exception as e:
print("Caught Unknown Error", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
app = FastAPI()
def release_model_semaphore(fn=None):
model_semaphore.release()
if fn is not None:
fn()
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
await model_semaphore.acquire()
worker.send_heart_beat()
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_get_status")
async def get_status(request: Request):
return worker.get_status()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--model-name", type=str)
parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
parser.add_argument("--limit-model-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=1)
parser.add_argument("--no-register", action="store_true")
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
if args.multi_modal:
logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
worker = ModelWorker(args.controller_address, args.worker_address, worker_id, args.no_register, args.model_path, args.model_base, args.model_name, args.load_8bit, args.load_4bit)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
"""
Manually register workers.
Usage:
python3 -m fastchat.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
"""
import argparse
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--controller-address", type=str)
parser.add_argument("--worker-name", type=str)
parser.add_argument("--check-heart-beat", action="store_true")
args = parser.parse_args()
url = args.controller_address + "/register_worker"
data = {
"worker_name": args.worker_name,
"check_heart_beat": args.check_heart_beat,
"worker_status": None,
}
r = requests.post(url, json=data)
assert r.status_code == 200
"""
A model worker that executes the model through an SGLang runtime endpoint.
"""
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import re
import uvicorn
from functools import partial
from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.utils import build_logger, server_error_msg, pretty_print_semaphore
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import AutoTokenizer
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import read_jsonl, dump_state_text
from sglang.lang.interpreter import ProgramState
GB = 1 << 30
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0
model_semaphore = None
def heart_beat_worker(controller):
while True:
time.sleep(WORKER_HEART_BEAT_INTERVAL)
controller.send_heart_beat()
@sgl.function
def pipeline(s, prompt, max_tokens):
for p in prompt:
if type(p) is str:
s += p
else:
s += sgl.image(p)
s += sgl.gen("response", max_tokens=max_tokens)
class ModelWorker:
def __init__(self, controller_addr, worker_addr, sgl_endpoint, worker_id, no_register, model_name):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
# Select backend
backend = RuntimeEndpoint(sgl_endpoint)
sgl.set_default_backend(backend)
model_path = backend.model_info["model_path"]
if model_path.endswith("/"):
model_path = model_path[:-1]
if model_name is None:
model_paths = model_path.split("/")
if model_paths[-1].startswith("checkpoint-"):
self.model_name = model_paths[-2] + "_" + model_paths[-1]
else:
self.model_name = model_paths[-1]
else:
self.model_name = model_name
logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
if not no_register:
self.register_to_controller()
self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,))
self.heart_beat_thread.start()
def register_to_controller(self):
logger.info("Register to controller")
url = self.controller_addr + "/register_worker"
data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
r = requests.post(url, json=data)
assert r.status_code == 200
def send_heart_beat(self):
logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " f"global_counter: {global_counter}")
url = self.controller_addr + "/receive_heart_beat"
while True:
try:
ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5)
exist = ret.json()["exist"]
break
except requests.exceptions.RequestException as e:
logger.error(f"heart beat error: {e}")
time.sleep(5)
if not exist:
self.register_to_controller()
def get_queue_length(self):
if model_semaphore is None:
return 0
else:
return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
def get_status(self):
return {
"model_names": [self.model_name],
"speed": 1,
"queue_length": self.get_queue_length(),
}
async def generate_stream(self, params):
ori_prompt = prompt = params["prompt"]
images = params.get("images", None)
if images is not None and len(images) > 0:
if len(images) > 0:
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
raise ValueError("Number of images does not match number of <image> tokens in prompt")
images = [load_image_from_base64(image) for image in images]
# FIXME: hacky padding
images = [expand2square(image, tuple(int(x * 255) for x in [0.48145466, 0.4578275, 0.40821073])) for image in images]
# FIXME: for image-start/end token
# replace_token = DEFAULT_IMAGE_TOKEN
# if getattr(self.model.config, 'mm_use_im_start_end', False):
# replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
# prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
prompt = prompt.replace(" " + DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN)
prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
prompt = []
for i in range(len(prompt_split)):
prompt.append(prompt_split[i])
if i < len(images):
prompt.append(images[i])
else:
prompt = [prompt]
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
# max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
stop_str = params.get("stop", None)
stop_str = [stop_str] if stop_str is not None else None
if max_new_tokens < 1:
yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
return
# print(prompt)
state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
generated_text = ori_prompt
async for text_outputs in state.text_async_iter(var_name="response"):
generated_text += text_outputs
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
async def generate_stream_gate(self, params):
try:
async for x in self.generate_stream(params):
yield x
except ValueError as e:
print("Caught ValueError:", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
except Exception as e:
print("Caught Unknown Error", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
app = FastAPI()
def release_model_semaphore(fn=None):
model_semaphore.release()
if fn is not None:
fn()
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
await model_semaphore.acquire()
worker.send_heart_beat()
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_get_status")
async def get_status(request: Request):
return worker.get_status()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
parser.add_argument("--model-name", type=str)
parser.add_argument("--sgl-endpoint", type=str)
parser.add_argument("--limit-model-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=1)
parser.add_argument("--no-register", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
worker = ModelWorker(args.controller_address, args.worker_address, args.sgl_endpoint, worker_id, args.no_register, args.model_name)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
import argparse
import json
import requests
from llava.conversation import default_conversation
def main():
if args.worker_address:
worker_addr = args.worker_address
else:
controller_addr = args.controller_address
ret = requests.post(controller_addr + "/refresh_all_workers")
ret = requests.post(controller_addr + "/list_models")
models = ret.json()["models"]
models.sort()
print(f"Models: {models}")
ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name})
worker_addr = ret.json()["address"]
print(f"worker_addr: {worker_addr}")
if worker_addr == "":
return
conv = default_conversation.copy()
conv.append_message(conv.roles[0], args.message)
prompt = conv.get_prompt()
headers = {"User-Agent": "LLaVA Client"}
pload = {
"model": args.model_name,
"prompt": prompt,
"max_new_tokens": args.max_new_tokens,
"temperature": 0.7,
"stop": conv.sep,
}
response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True)
print(prompt.replace(conv.sep, "\n"), end="")
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"].split(conv.sep)[-1]
print(output, end="\r")
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
parser.add_argument("--worker-address", type=str)
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--max-new-tokens", type=int, default=32)
parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.")
args = parser.parse_args()
main()
from typing import Optional, Tuple
import warnings
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.")
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transform the data into the format required by flash attention
qkv = torch.stack([query_states, key_states, value_states], dim=2)
qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
max_s = q_len
output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
output = output.view(bsz, q_len, -1)
else:
qkv = qkv.reshape(bsz, q_len, -1)
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len)
return self.o_proj(output), None, past_key_value
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593")
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
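# Usage note (illustrative, not in the original file): the monkey patch has to be
# applied *before* the LLaMA model is instantiated so that the patched
# `LlamaAttention.forward` is the one bound to the model, e.g. (module path assumed):
#
#     from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
#     replace_llama_attn_with_flash_attn()
#     model = transformers.LlamaForCausalLM.from_pretrained("path/to/llama", torch_dtype=torch.float16)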
import os
import torch
import torch.nn as nn
import datetime
from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin
from torch.utils.data import Dataset, Sampler, DataLoader
from trl.trainer import DPOTrainer
from trl.trainer.utils import DPODataCollatorWithPadding
from transformers import Trainer
from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, logger, is_accelerate_available, is_datasets_available, GradientAccumulationPlugin
from transformers.trainer_utils import seed_worker
from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf
from transformers.trainer_pt_utils import AcceleratorConfig
from typing import List, Optional
from datetime import timedelta
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs
if is_datasets_available():
import datasets
from llava.utils import rank0_print
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
print(name, "no ignore status")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
return to_return
def split_to_even_chunks(indices, lengths, num_chunks):
"""
Split a list of indices into `chunks` chunks of roughly equal lengths.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
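# Worked example (illustrative only): with the global lengths [3, 9, 7, 5] and a
# megabatch already sorted by descending length,
#
#     split_to_even_chunks([1, 2, 3, 0], [3, 9, 7, 5], num_chunks=2)
#     # -> [[1, 0], [2, 3]]   (per-chunk total lengths: 12 and 12)
#
# each index is greedily assigned to the currently-shortest chunk; when the megabatch
# size is not divisible by `num_chunks`, the function falls back to simple striding.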
def get_variable_length_grouped_indices(lengths, batch_size, world_size, megabatch_mult=8, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
megabatch_size = world_size * batch_size * megabatch_mult
megabatches = [sorted_indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: indices[i], reverse=True) for megabatch in megabatches]
shuffled_indices = [i for megabatch in megabatches for i in megabatch]
world_batch_size = world_size * batch_size
batches = [shuffled_indices[i : i + world_batch_size] for i in range(0, len(lengths), world_batch_size)]
batch_indices = torch.randperm(len(batches), generator=generator)
batches = [batches[i] for i in batch_indices]
return [i for batch in batches for i in batch]
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
"""
Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
lengths. To do this, the indices are:
- randomly permuted
- grouped in mega-batches of size `mega_batch_mult * batch_size`
- reorder by length in each mega-batch
The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
maximum length placed first, so that an OOM happens sooner rather than later.
"""
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
if len(additional_batch) > 0:
megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
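# Illustrative note (hedged assumption about the dataset's sign convention): the
# `lengths` passed here typically come from `modality_lengths`, where positive values
# mark multimodal samples and negative values mark language-only samples, e.g.
#
#     lengths = [120, -45, 98, -300, 87, -12]   # three image samples, three text-only
#
# Image and text indices are length-grouped separately, so a megabatch normally
# contains only one modality; only the leftover tail of both groups may end up mixed.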
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
"""
Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
lengths. To do this, the indices are:
- randomly permuted
- grouped in mega-batches of size `mega_batch_mult * batch_size`
- reorder by length in each mega-batch
The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
maximum length placed first, so that an OOM happens sooner rather than later.
"""
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
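# Sketch (illustrative): with batch_size=2 and world_size=2, each megabatch holds 4
# indices; after sorting by length it is split into 2 per-rank chunks of balanced
# total length, e.g. indices [1, 2, 3, 0] with lengths [3, 9, 7, 5] become
# [[1, 0], [2, 3]] via `split_to_even_chunks` above.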
def get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=None):
indices = get_length_grouped_indices_hf(lengths, batch_size * world_size, generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
batch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in batch_indices]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
def get_modality_length_grouped_indices_auto(lengths, batch_size, world_size, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices_auto_single(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices_auto_single(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
# FIXME: Hard code to avoid last batch mixed with different modalities
# if len(additional_batch) > 0:
# megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
lengths: Optional[List[int]] = None,
generator=None,
variable_length: bool = False,
group_by_modality: bool = False,
group_by_modality_auto: bool = False,
):
if lengths is None:
raise ValueError("Lengths must be provided.")
self.batch_size = batch_size
self.world_size = world_size
self.lengths = lengths
self.generator = generator
self.variable_length = variable_length
self.group_by_modality = group_by_modality
self.group_by_modality_auto = group_by_modality_auto
def __len__(self):
return len(self.lengths)
def __iter__(self):
if self.variable_length:
assert not self.group_by_modality, "Variable length grouping is not supported with modality grouping."
indices = get_variable_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
if self.group_by_modality:
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
elif self.group_by_modality_auto:
indices = get_modality_length_grouped_indices_auto(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
indices = get_length_grouped_indices_auto_single(self.lengths, self.batch_size, self.world_size, generator=self.generator)
return iter(indices)
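# Hedged usage sketch (not part of the original file): the trainer below builds this
# sampler from dataset-provided lengths, roughly equivalent to
#
#     sampler = LengthGroupedSampler(
#         batch_size=8,
#         world_size=4,
#         lengths=train_dataset.modality_lengths,  # assumed attribute of the llava dataset
#         group_by_modality=True,
#     )
#     loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)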
class LLaVATrainer(Trainer):
def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
rank0_print("Setting NCCL timeout to INF to avoid running errors.")
# create accelerator object
self.accelerator = Accelerator(
dispatch_batches=self.args.dispatch_batches, split_batches=self.args.split_batches, deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, kwargs_handlers=[accelerator_kwargs]
)
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", fsdp_plugin.limit_all_gathers)
if is_accelerate_available("0.23.0"):
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get("activation_checkpointing", fsdp_plugin.activation_checkpointing)
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
raise ValueError("The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " "when using FSDP.")
if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
self.propagate_args_to_deepspeed()
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_length:
lengths = self.train_dataset.lengths
return LengthGroupedSampler(
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
self.args.train_batch_size,
# world_size=self.args.world_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
lengths=lengths,
)
elif self.args.group_by_modality_length:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
self.args.train_batch_size,
# world_size=self.args.world_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
lengths=lengths,
group_by_modality=True,
)
elif self.args.group_by_modality_length_auto:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
self.args.train_batch_size,
# world_size=self.args.world_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
lengths=lengths,
group_by_modality_auto=True,
)
elif self.args.group_by_varlen:
lengths = self.train_dataset.lengths
return LengthGroupedSampler(
self.args.train_batch_size * self.args.gradient_accumulation_steps,
# self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps
# world_size=self.args.world_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
lengths=lengths,
variable_length=True,
)
else:
return super()._get_train_sampler()
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None
dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
return dataloader
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()
opt_model = self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
lr_mapper = {}
if self.args.mm_projector_lr is not None:
lr_mapper["mm_projector"] = self.args.mm_projector_lr
if self.args.mm_vision_tower_lr is not None:
lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr
if len(lr_mapper) > 0:
special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
optimizer_grouped_parameters = [
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
"weight_decay": 0.0,
},
]
for module_keyword, lr in lr_mapper.items():
module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
optimizer_grouped_parameters.extend(
[
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
"lr": lr,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)],
"weight_decay": 0.0,
"lr": lr,
},
]
)
else:
optimizer_grouped_parameters = [
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
return self.optimizer
def _save_checkpoint(self, model, trial, metrics=None):
if getattr(self.args, "tune_mm_mlp_adapter", False) or (
hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts))
):
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
# Only save Adapter
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
if self.args.local_rank == 0 or self.args.local_rank == -1:
self.model.config.save_pretrained(output_dir)
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
else:
super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, "tune_mm_mlp_adapter", False):
pass
else:
super(LLaVATrainer, self)._save(output_dir, state_dict)
class LLaVADPOTrainer(DPOTrainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_modality_length:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
self.args.train_batch_size,
world_size=self.args.world_size,
lengths=lengths,
group_by_modality=True,
)
else:
return super()._get_train_sampler()
def _save_checkpoint(self, model, trial, metrics=None):
if getattr(self.args, "tune_mm_mlp_adapter", False) or (
hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts))
):
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
# Only save Adapter
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
if self.args.local_rank == 0 or self.args.local_rank == -1:
self.model.config.save_pretrained(output_dir)
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
else:
# super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics)
# print(type(model))
# from transformers.modeling_utils import unwrap_model
# print(type(unwrap_model(model)))
# print(unwrap_model(model).config)
if self.args.lora_enable:
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
from transformers.modeling_utils import unwrap_model
unwrapped_model = unwrap_model(model)
self.save_my_lora_ckpt(output_dir, self.args, unwrapped_model)
else:
super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, "tune_mm_mlp_adapter", False):
pass
else:
super(LLaVADPOTrainer, self)._save(output_dir, state_dict)
import json
import subprocess
from llava.train.llava_trainer import LLaVATrainer
class LLaVAEvalTrainer(LLaVATrainer):
def evaluate(self, evaluate_args):
cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \
--model {evaluate_args.model} \
--model_args {evaluate_args.model_args} \
--tasks {evaluate_args.task_names} \
--batch_size {evaluate_args.batch_size} \
--log_samples_suffix {evaluate_args.log_samples_suffix} \
--output_path {evaluate_args.output_path}"
if evaluate_args.limit:
cmd += f" --limit {evaluate_args.limit}"
if evaluate_args.num_fewshot:
cmd += f" --num_fewshot {evaluate_args.num_fewshot}"
if evaluate_args.gen_kwargs != "":
cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}"
if evaluate_args.log_samples:
cmd += f" --log_samples"
else:
assert False, "Please log samples so that the result can be parsed"
results = subprocess.run([cmd], shell=True, capture_output=True, text=True)
try:
result_file_index_start = results.stdout.index("Saved samples to ")
result_file_index_end = results.stdout.index(f".json")
result_file_index_start += len("Saved samples to ")
file = results.stdout[result_file_index_start:result_file_index_end]
        except ValueError:
result_file_index_start = results.stderr.index("Saved samples to ")
result_file_index_end = results.stderr.index(f".json")
result_file_index_start += len("Saved samples to ")
file = results.stderr[result_file_index_start:result_file_index_end]
file = file.split("/")[:-1]
file = "/".join(file) + "/results.json"
with open(file, "r") as f:
lmms_eval_results = json.load(f)
result_dict = {}
tasks_list = evaluate_args.task_names.split(",")
for task in tasks_list:
task_results = lmms_eval_results["results"][task]
for k, v in task_results.items():
if k != "alias" and "stderr" not in k:
metric = k.split(",")[0]
result_dict[f"{task}_{metric}"] = v
return result_dict
"""def evaluate(self, evaluate_args):
initialize_tasks()
tasks_list = evaluate_args.task_names.split(",")
result_dict = {}
results = evaluator.simple_evaluate(
model=evaluate_args.model,
model_args=evaluate_args.model_args,
tasks=tasks_list,
num_fewshot=evaluate_args.num_fewshot,
batch_size=evaluate_args.batch_size,
device=evaluate_args.device,
limit=evaluate_args.limit,
check_integrity=evaluate_args.check_integrity,
show_task_to_terminal=evaluate_args.show_task_to_terminal,
log_samples=evaluate_args.log_samples,
gen_kwargs=evaluate_args.gen_kwargs,
cli_args=evaluate_args,
)
for task in tasks_list:
task_results = results["results"][task]
for k,v in task_results.items():
if k != "alias" and "stderr" not in k:
metric = k.split(",")[0]
result_dict[f"{task}_{metric}"] = v
return result_dict"""
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
from PIL import Image, ImageFile
from packaging import version
import numpy as np
import time
import random
import yaml
import math
import re
import torch
import transformers
import tokenizers
import deepspeed
from transformers import AutoConfig
from torch.utils.data import Dataset
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
from llava.train.llava_trainer import LLaVATrainer
from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token
from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord
torch.multiprocessing.set_sharing_strategy("file_system")
ImageFile.LOAD_TRUNCATED_IMAGES = True
local_rank = None
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse("0.14")
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    model_class_name: Optional[str] = field(default=None, metadata={"help": "Used to initialize the model class; the format is XXXXForCausalLM, where XXXX is currently one of LlavaLlama, LlavaMixtral, LlavaMistral or Llama."})
mm_tunable_parts: Optional[str] = field(
        default=None, metadata={"help": 'Could be "mm_mlp_adapter", "mm_vision_resampler", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_mlp_adapter,mm_language_model"'}
    )
    # Decides which parts of the multimodal model to tune; overrides the other tuning flags above.
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=False)
tune_mm_mlp_adapter: bool = field(default=False)
tune_mm_vision_resampler: bool = field(default=False)
vision_tower: Optional[str] = field(default=None)
    vision_tower_pretrained: Optional[str] = field(default=None)
unfreeze_mm_vision_tower: bool = field(default=False)
unfreeze_language_model: bool = field(default=False)
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
mm_projector_type: Optional[str] = field(default="linear")
mm_use_im_start_end: bool = field(default=False)
mm_use_im_patch_token: bool = field(default=True)
mm_patch_merge_type: Optional[str] = field(default="flat")
mm_vision_select_feature: Optional[str] = field(default="patch")
mm_resampler_type: Optional[str] = field(default=None)
mm_mask_drop_mode: str = field(default="fixed")
mm_mask_drop_skip_percentage: float = field(default=0.0)
mm_mask_drop_ratio: float = field(default=0.25)
mm_mask_drop_ratio_upper: Optional[float] = field(default=None)
mm_mask_drop_ratio_lower: Optional[float] = field(default=None)
mm_spatial_pool_stride: Optional[int] = field(default=None)
mm_spatial_pool_mode: str = field(default="bilinear")
mm_spatial_pool_out_channels: Optional[int] = field(default=None)
mm_perceiver_depth: Optional[int] = field(default=3)
mm_perceiver_latents: Optional[int] = field(default=32)
mm_perceiver_ff_mult: Optional[float] = field(default=4)
mm_perceiver_pretrained: Optional[str] = field(default=None)
mm_qformer_depth: Optional[int] = field(default=3)
mm_qformer_latents: Optional[int] = field(default=32)
mm_qformer_pretrained: Optional[str] = field(default=None)
rope_scaling_factor: Optional[float] = field(default=None)
rope_scaling_type: Optional[str] = field(default=None)
s2: Optional[bool] = field(default=False)
s2_scales: Optional[str] = field(default="336,672,1008")
use_pos_skipping: Optional[bool] = field(default=False)
pos_skipping_range: Optional[int] = field(default=4096)
mm_newline_position: Optional[str] = field(default="grid")
delay_load: Optional[bool] = field(default=True)
add_faster_video: Optional[bool] = field(default=False)
faster_token_stride: Optional[int] = field(default=10)
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data, in llava's instruction.json format. Supporting multiple json files via /path/to/{a,b,c}.json"})
lazy_preprocess: bool = False
is_multimodal: bool = False
early_mix_text: bool = False
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: str = "square"
image_grid_pinpoints: Optional[str] = field(default=None)
image_crop_resolution: Optional[int] = field(default=None)
image_split_resolution: Optional[int] = field(default=None)
video_folder: Optional[str] = field(default=None)
video_fps: Optional[int] = field(default=1)
frames_upbound: Optional[int] = field(default=0)
add_time_instruction: Optional[bool] = field(default=False)
force_sample: Optional[bool] = field(default=False)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
remove_unused_columns: bool = field(default=False)
freeze_mm_mlp_adapter: bool = field(default=False)
freeze_mm_vision_resampler: bool = field(default=False)
mpt_attn_impl: Optional[str] = field(default="triton")
model_max_length: int = field(
default=4096,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
double_quant: bool = field(default=True, metadata={"help": "Compress the quantization statistics through double quantization."})
quant_type: str = field(default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."})
bits: int = field(default=16, metadata={"help": "How many bits to use."})
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
mm_projector_lr: Optional[float] = None
mm_vision_tower_lr: Optional[float] = None
group_by_varlen: bool = field(default=False)
group_by_modality_length: bool = field(default=False)
group_by_modality_length_auto: bool = field(default=False)
auto_find_batch_size: bool = field(default=False)
gradient_checkpointing: bool = field(default=True)
verbose_logging: bool = field(default=False)
attn_implementation: str = field(default="flash_attention_2", metadata={"help": "Use transformers attention implementation."})
# @dataclass
# class EvaluationArguments:
# eval_num_processes: int = field(default=1)
# task_names: str = field(default=None)
# model: str = field(default="llava")
# model_args: Optional[str] = field(default=None)
# num_fewshot: Optional[int] = field(default=None)
# batch_size: int = field(default=1)
# device: Optional[str] = field(default=None)
# limit: Optional[int] = field(default=None)
# check_integrity: Optional[bool] = field(default=False)
# show_task_to_terminal: Optional[bool] = field(default=False)
# log_samples: Optional[bool] = field(default=True)
# gen_kwargs: Optional[str] = field(default="")
# log_samples_suffix: Optional[str] = field(default="")
# output_path: Optional[str] = field(default="./logs/")
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)
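# Illustrative usage (hedged): the returned module names are typically passed to
# PEFT's LoraConfig so that LoRA is attached only to the language model's linear
# layers, skipping the vision tower, resampler and projector, e.g.
#
#     from peft import LoraConfig, get_peft_model
#     lora_config = LoraConfig(
#         r=64, lora_alpha=16, lora_dropout=0.05, bias="none",
#         target_modules=find_all_linear_names(model), task_type="CAUSAL_LM",
#     )
#     model = get_peft_model(model, lora_config)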
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
    if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter:
        check_only_save_mm_adapter_tunable = True
    # only has mm_mlp_adapter and mm_vision_resampler in the tunable parts
    elif hasattr(trainer.args, "mm_tunable_parts") and (len(trainer.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in trainer.args.mm_tunable_parts or "mm_vision_resampler" in trainer.args.mm_tunable_parts)):
        check_only_save_mm_adapter_tunable = True
    else:
        check_only_save_mm_adapter_tunable = False
trainer.accelerator.wait_for_everyone()
torch.cuda.synchronize()
rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}")
if check_only_save_mm_adapter_tunnable:
# Only save Adapter
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"))
else:
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
return
if trainer.deepspeed:
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
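# Hedged example: a typical call adds a pad token to a tokenizer that lacks one and
# initializes the new embedding row(s) to the mean of the existing rows, as above:
#
#     smart_tokenizer_and_embedding_resize(
#         special_tokens_dict=dict(pad_token="[PAD]"),
#         tokenizer=tokenizer,
#         model=model,
#     )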
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(target, tokenized_lens, speakers):
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker == "human":
target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
def _add_speaker_and_signal(header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = "### "
END_SIGNAL = "\n"
conversation = header
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = conversation_lib.default_conversation.roles[1]
else:
from_str = "unknown"
sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
if get_conversation:
conversation += sentence["value"]
conversation += BEGIN_SIGNAL
return conversation
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
# TODO maybe this should be changed for interleaved data?
# if DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
# only check for num_im=1
num_im = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
if num_im == 1 and DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>")
replace_token = DEFAULT_IMAGE_TOKEN
if data_args.mm_use_im_start_end:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
# For videoInstruct-100k noisy_data. TODO: Ask Yuanhan to clean the data instead of leaving the noise code here.
sentence["value"] = sentence["value"].replace("QA_GT_caption_based_noisy", "")
return sources
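# Worked example (illustrative): with a single image reference and
# mm_use_im_start_end=False, a human turn such as
#     "What is shown here? <image>"
# is rewritten so the image token sits at the front on its own line:
#     "<image>\nWhat is shown here?"
# With mm_use_im_start_end=True, <image> is additionally wrapped in the image
# start/end tokens; with a "mmtag" conversation version it is also wrapped in
# <Image>...</Image>.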
def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
# Mask targets
sep = "[/INST] "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_gemma(sources: List[List[Dict[str, str]]], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
conv: conversation_lib.Conversation = conversation_lib.default_conversation.copy()
roles: Dict[str, str] = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations: List[str] = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source: List[Dict[str, str]] = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role: str = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids: torch.Tensor = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
input_ids: torch.Tensor = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets: torch.Tensor = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA
# Mask target
sep: str = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
total_len: int = int(target.ne(tokenizer.pad_token_id).sum())
rounds: List[str] = conversation.split(conv.sep)
re_rounds = []
for conv_idx in range(0, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2]))
cur_len = 1 # Ignore <bos>
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep # Re-append sep because split on this
# Now "".join(parts)==rou
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 # Ignore <bos>
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # Ignore <bos>
else:
round_len = len(tokenizer(rou).input_ids) - 1 # Ignore <bos>
instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # Ignore <bos>
round_len += 2 # sep: <end_of_turn>\n takes 2 tokens
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"warning: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
# roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
roles = {"human": "user", "gpt": "assistant"}
    # Add the image token to the tokenizer as a special token.
    # Use a deepcopy of the tokenizer so that the original is not modified.
tokenizer = copy.deepcopy(tokenizer)
# When there is actually an image, we add the image tokens as a special token
if has_image:
tokenizer.add_tokens(["<image>"], special_tokens=True)
image_token_index = tokenizer.convert_tokens_to_ids("<image>")
im_start, im_end = tokenizer.additional_special_tokens_ids
    # unmask_tokens = ["<|im_start|>", "<|im_end|>", "\n"]
unmask_tokens_idx = [198, im_start, im_end]
nl_tokens = tokenizer("\n").input_ids
    # Reset the Qwen chat template so that it does not prepend a system message every time it is applied
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer.chat_template = chat_template
# _system = tokenizer("system").input_ids + nl_tokens
# _user = tokenizer("user").input_ids + nl_tokens
# _assistant = tokenizer("assistant").input_ids + nl_tokens
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
# New version, use apply chat template
# Build system message for each sentence
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
target += [IGNORE_INDEX] * len(input_id)
for conv in source:
            # Support both openai-style ("role"/"content") and llava-style ("from"/"value") records
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]
role = roles.get(role, role)
conv = [{"role" : role, "content" : content}]
encode_id = tokenizer.apply_chat_template(conv)
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
for idx, encode_id in enumerate(input_id):
if encode_id in unmask_tokens_idx:
target[idx] = encode_id
if encode_id == image_token_index:
input_id[idx] = IMAGE_TOKEN_INDEX
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids, # tensor(bs x seq_len)
labels=targets, # tensor(bs x seq_len)
)
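# Hypothetical usage sketch (the sample data below is illustrative only):
#   sources = [[{"from": "human", "value": "<image>\nWhat is in the picture?"},
#               {"from": "gpt", "value": "A cat."}]]
#   batch = preprocess_qwen(sources, tokenizer, has_image=True)
#   # batch["input_ids"] holds the templated token ids with <image> mapped to
#   # IMAGE_TOKEN_INDEX; batch["labels"] masks system/user spans with IGNORE_INDEX.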
def preprocess_llama3(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
max_len=2048,
system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
) -> Dict:
# roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
roles = {"human": "user", "gpt": "assistant"}
# Add the image token to the tokenizer as a special token
# Use a deepcopy of the tokenizer so that we don't modify the original
tokenizer = copy.deepcopy(tokenizer)
# Only register the <image> special token when the sample actually contains an image
if has_image:
tokenizer.add_tokens(["<image>"], special_tokens=True)
image_token_index = tokenizer.convert_tokens_to_ids("<image>")
bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"]
unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
# After a tokenizers update, calling the llama3 tokenizer
# automatically prepends the bos id, so strip it here. ヽ(`⌒´)ノ
def safe_tokenizer_llama3(text):
input_ids = tokenizer(text).input_ids
if input_ids[0] == bos_token_id:
input_ids = input_ids[1:]
return input_ids
nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
# New version, use apply chat template
# Build system message for each sentence
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
target += [IGNORE_INDEX] * len(input_id)
for conv in source:
# Support both chat-style ("role"/"content") and llava-style ("from"/"value") records
try:
role = conv["role"]
content = conv["content"]
except KeyError:
role = conv["from"]
content = conv["value"]
role = roles.get(role, role)
conv = [{"role" : role, "content" : content}]
# First is bos token we don't need here
encode_id = tokenizer.apply_chat_template(conv)[1:]
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
for idx, encode_id in enumerate(input_id):
if encode_id in unmask_tokens_idx:
target[idx] = encode_id
if encode_id == image_token_index:
input_id[idx] = IMAGE_TOKEN_INDEX
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids, # tensor(bs x seq_len)
labels=targets, # tensor(bs x seq_len)
)
def preprocess_v1(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
round_len -= 1
instruction_len -= 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_mpt(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
# Mask targets
sep = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
for conv_idx in range(3, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
if i != 0 and getattr(tokenizer, "legacy", False) and IS_TOKENIZER_GREATER_THAN_0_14:
round_len += 1
instruction_len += 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f"(#turns={len(re_rounds)} ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
source[0]["value"] = DEFAULT_IMAGE_TOKEN
conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
conversations.append(conversation)
# tokenize conversations
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
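# Note: in plain (pretraining) mode only the image placeholder is kept from the human
# turn, so the supervision signal comes entirely from the caption tokens.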
def preprocess(sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
"""
Given a list of sources, each is a conversation list. This transform:
1. Add the signal '### ' at the beginning of each sentence, with the end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
return preprocess_llama_2(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version.startswith("v1"):
return preprocess_v1(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "mpt":
return preprocess_mpt(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "qwen":
return preprocess_qwen(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "gemma":
return preprocess_gemma(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "llama_v3":
return preprocess_llama3(sources, tokenizer, has_image=has_image)
# add end signal and concatenate together
conversations = []
for source in sources:
header = f"{conversation_lib.default_conversation.system}\n\n"
conversation = _add_speaker_and_signal(header, source)
conversations.append(conversation)
# tokenize conversations
def get_tokenize_len(prompts):
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
if has_image:
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
else:
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
if has_image:
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
else:
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
speakers = [sentence["from"] for sentence in source]
_mask_targets(target, tokenized_lens, speakers)
return dict(input_ids=input_ids, labels=targets)
class LazySupervisedDataset(Dataset):
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
super(LazySupervisedDataset, self).__init__()
self.tokenizer = tokenizer
self.list_data_dict = []
# Handle multiple JSON files specified in the data_path
if "{" in data_path and "}" in data_path:
base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
file_names = file_pattern.split(",")
rank0_print(f"Loading {file_names} from {base_path}")
data_args.dataset_paths = []
for file_name in file_names:
data_args.dataset_paths.append(f"{base_path}{file_name}.json")
full_path = f"{base_path}{file_name}.json"
rank0_print(f"Loading {full_path}")
with open(full_path, "r") as file:
cur_data_dict = json.load(file)
rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
self.list_data_dict.extend(cur_data_dict)
elif data_path.endswith(".yaml"):
with open(data_path, "r") as file:
yaml_data = yaml.safe_load(file)
datasets = yaml_data.get("datasets")
# file should be in the format of:
# datasets:
# - json_path: xxxx1.json
# sampling_strategy: first:1000
# - json_path: xxxx2.json
# sampling_strategy: end:3000
# - json_path: xxxx3.json
# sampling_strategy: random:999
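# sampling_strategy also accepts percentages, e.g. "random:10%" keeps roughly
# 10% of the samples in that file (see the parsing below).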
data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
for dataset in datasets:
json_path = dataset.get("json_path")
sampling_strategy = dataset.get("sampling_strategy", "all")
sampling_number = None
rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
if json_path.endswith(".jsonl"):
cur_data_dict = []
with open(json_path, "r") as json_file:
for line in json_file:
cur_data_dict.append(json.loads(line.strip()))
elif json_path.endswith(".json"):
with open(json_path, "r") as json_file:
cur_data_dict = json.load(json_file)
else:
raise ValueError(f"Unsupported file type: {json_path}")
if ":" in sampling_strategy:
sampling_strategy, sampling_number = sampling_strategy.split(":")
if "%" in sampling_number:
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
else:
sampling_number = int(sampling_number)
# Apply the sampling strategy
if sampling_strategy == "first" and sampling_number is not None:
cur_data_dict = cur_data_dict[:sampling_number]
elif sampling_strategy == "end" and sampling_number is not None:
cur_data_dict = cur_data_dict[-sampling_number:]
elif sampling_strategy == "random" and sampling_number is not None:
random.shuffle(cur_data_dict)
cur_data_dict = cur_data_dict[:sampling_number]
rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
self.list_data_dict.extend(cur_data_dict)
else:
data_args.dataset_paths = [data_path]
rank0_print(f"Loading {data_path}")
with open(data_path, "r") as file:
cur_data_dict = json.load(file)
rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
self.list_data_dict.extend(cur_data_dict)
rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if "image" in sample else 0
length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
assert cur_len > 0, f"Conversation length is 0 for {sample}"
if "image" in sample or "video" in sample or self.data_args.early_mix_text:
length_list.append(cur_len)
else:
length_list.append(-cur_len)
return length_list
def process_image(self, image_file, overwrite_image_aspect_ratio=None):
image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
# print(f"\n\nInspecting the image path, folder = {image_folder}, image={image_file}\n\n")
try:
image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
except Exception as exn:
print(f"Failed to open image {image_file}. Exception:", exn)
raise exn
image_size = image.size
image_aspect_ratio = self.data_args.image_aspect_ratio
if overwrite_image_aspect_ratio is not None:
image_aspect_ratio = overwrite_image_aspect_ratio
if image_aspect_ratio == "highres":
image = process_highres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
elif image_aspect_ratio == "crop_split":
image = process_highres_image_crop_split(image, self.data_args)
elif image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
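# expand2square pads the shorter side with the processor's mean color so the
# image becomes square before the usual preprocess/resize step.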
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
else:
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
return image, image_size, "image"
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# TODO: define number of retries somewhere else
num_base_retries = 3
num_final_retries = 300
# try the current sample first
for attempt_idx in range(num_base_retries):
try:
sample = self._get_item(i)
return sample
except Exception as e:
# sleep 1s in case it is a cloud disk issue
print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
time.sleep(1)
# try other samples, in case it is file corruption issue
for attempt_idx in range(num_base_retries):
try:
next_index = min(i + 1, len(self.list_data_dict) - 1)
# sample_idx = random.choice(range(len(self)))
sample = self._get_item(next_index)
return sample
except Exception as e:
# no need to sleep
print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)
pass
try:
sample = self._get_item(i)
return sample
except Exception as e:
raise e
def _get_item(self, i) -> Dict[str, torch.Tensor]:
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped in a list" # FIXME
if "image" in sources[0]:
image_file = self.list_data_dict[i]["image"]
if type(image_file) is list:
image = [self.process_image(f) for f in image_file]
# Handling multi images
# overwrite to process with simple pad
if len(image_file) > 1:
image = [self.process_image(f, "pad") for f in image_file]
image = [[im[0], im[1], "image"] for im in image]
else:
image = [self.process_image(image_file)]
sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
elif "video" in sources[0]:
video_file = self.list_data_dict[i]["video"]
video_folder = self.data_args.video_folder
video_file = os.path.join(video_folder, video_file)
suffix = video_file.split(".")[-1]
if not os.path.exists(video_file):
print("File {} not exist!".format(video_file))
try:
if "shareVideoGPTV" in video_file:
frame_files = [os.path.join(video_file, f) for f in os.listdir(video_file) if os.path.isfile(os.path.join(video_file, f))]
frame_files.sort() # Ensure the frames are sorted if they are named sequentially
# TODO: hard-coded; determine the indices for uniformly sampling 10 frames
if self.data_args.force_sample:
num_frames_to_sample = self.data_args.frames_upbound
else:
num_frames_to_sample = 10
avg_fps = 2
total_frames = len(frame_files)
sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
frame_time = [i/2 for i in sampled_indices]
frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
video_time = total_frames / avg_fps
# Read and store the sampled frames
video = []
for idx in sampled_indices:
frame_path = frame_files[idx]
try:
with Image.open(frame_path) as img:
frame = img.convert("RGB")
video.append(frame)
except IOError:
print(f"Failed to read frame at path: {frame_path}")
else:
video, video_time, frame_time, num_frames_to_sample = process_video_with_decord(video_file, self.data_args)
processor = self.data_args.image_processor
image = processor.preprocess(video, return_tensors="pt")["pixel_values"]
if self.data_args.add_time_instruction:
time_instruction = f"The video lasts for {video_time:.2f} seconds, and {num_frames_to_sample} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
sources[0]["conversations"][0]["value"] = f'{DEFAULT_IMAGE_TOKEN}\n{time_instruction}\n{sources[0]["conversations"][0]["value"].replace(DEFAULT_IMAGE_TOKEN, "")}'
image = [(image, video[0].size, "video")]
sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
# print(sources)
except Exception as e:
print(f"Error: {e}")
print(f"Failed to read video file: {video_file}")
return self._get_item(i + 1)
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
has_image = ("image" in self.list_data_dict[i]) or ("video" in self.list_data_dict[i])
data_dict = preprocess(sources, self.tokenizer, has_image=has_image)
if "prompt" in data_dict:
prompt = data_dict["prompt"]
else:
prompt = None
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
# image exists in the data
if "image" in self.list_data_dict[i]:
data_dict["image"] = image
elif "video" in self.list_data_dict[i]:
data_dict["image"] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict["image"] = [
(torch.zeros(1, 3, crop_size["height"], crop_size["width"]), (crop_size["width"], crop_size["height"]), "text"),
]
# prompt exists in the data
if prompt is not None:
data_dict["prompt"] = prompt
data_dict["id"] = self.list_data_dict[i].get("id", i)
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def pad_sequence(self, input_ids, batch_first, padding_value):
if self.tokenizer.padding_side == "left":
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
if self.tokenizer.padding_side == "left":
input_ids = torch.flip(input_ids, [1])
return input_ids
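# The flip / right-pad / flip-back trick above effectively left-pads the batch
# when the tokenizer is configured with padding_side == "left".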
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
# input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "id"))
input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids]
labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels]
if self.tokenizer.pad_token_id is None:
# self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model.
self.tokenizer.pad_token_id = 0 # This gets the best result. Don't know why.
input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
batch = dict(input_ids=input_ids, labels=labels.long() if labels.dtype == torch.int32 else labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id))
# batch = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ids=ids)
if "image" in instances[0]:
images = [instance["image"] for instance in instances]
batch["image_sizes"] = [im[1] for im_list in images for im in im_list]
batch["modalities"] = [im[2] for im_list in images for im in im_list]
images = [im[0] for im_list in images for im in im_list]
# if all(x is not None and x.shape == images[0].shape for x in images):
# Image: (N, P, C, H, W)
# Video: (N, F, C, H, W)
# batch["images"] = torch.stack(images)
# else:
batch["images"] = images
if "prompt" in instances[0]:
batch["prompts"] = [instance["prompt"] for instance in instances]
return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
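# Hypothetical usage sketch (mirrors how train() consumes this module below):
#   data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
#   trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)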
def get_model(model_args, training_args, bnb_model_from_pretrained_args):
assert training_args.attn_implementation
if training_args.attn_implementation == "sdpa" and torch.__version__ < "2.1.2":
raise ValueError("The 'sdpa' attention implementation requires torch version 2.1.2 or higher.")
customized_kwargs = dict()
customized_kwargs.update(bnb_model_from_pretrained_args)
cfg_pretrained = None
overwrite_config = {}
if any(
[
model_args.rope_scaling_factor is not None,
model_args.rope_scaling_type is not None,
model_args.mm_spatial_pool_stride is not None,
model_args.mm_spatial_pool_out_channels is not None,
model_args.mm_spatial_pool_mode is not None,
model_args.mm_resampler_type is not None,
]
):
cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
if model_args.use_pos_skipping is not None and model_args.pos_skipping_range is not None:
overwrite_config["use_pos_skipping"] = model_args.use_pos_skipping
overwrite_config["pos_skipping_range"] = model_args.pos_skipping_range
if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
overwrite_config["rope_scaling"] = {
"factor": model_args.rope_scaling_factor,
"type": model_args.rope_scaling_type,
}
if training_args.model_max_length is None:
training_args.model_max_length = cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor
overwrite_config["max_sequence_length"] = training_args.model_max_length
assert training_args.model_max_length == int(cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor), print(
f"model_max_length: {training_args.model_max_length}, max_position_embeddings: {cfg_pretrained.max_position_embeddings}, rope_scaling_factor: {model_args.rope_scaling_factor}"
)
# overwrite_config["max_sequence_length"] = model_args.max_sequence_length
# overwrite_config["tokenizer_model_max_length"] = model_args.tokenizer_model_max_length
if model_args.mm_spatial_pool_stride is not None and model_args.mm_spatial_pool_out_channels is not None and model_args.mm_spatial_pool_mode is not None and model_args.mm_resampler_type is not None:
overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type
overwrite_config["mm_spatial_pool_stride"] = model_args.mm_spatial_pool_stride
overwrite_config["mm_spatial_pool_out_channels"] = model_args.mm_spatial_pool_out_channels
overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode
if model_args.mm_spatial_pool_mode is not None:
overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode
if overwrite_config:
assert cfg_pretrained is not None, "cfg_pretrained is None"
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(cfg_pretrained, k, v)
customized_kwargs["config"] = cfg_pretrained
if model_args.model_class_name is not None:
actual_model_class_name = f"{model_args.model_class_name}ForCausalLM"
model_class = getattr(transformers, actual_model_class_name)
rank0_print(f"Using model class {model_class} from {model_args.model_class_name}")
model = model_class.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
elif model_args.vision_tower is not None:
if "mixtral" in model_args.model_name_or_path.lower():
model = LlavaMixtralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
model = LlavaMistralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
elif (
"wizardlm-2" in model_args.model_name_or_path.lower()
or "vicuna" in model_args.model_name_or_path.lower()
or "llama" in model_args.model_name_or_path.lower()
or "yi" in model_args.model_name_or_path.lower()
or "nous-hermes" in model_args.model_name_or_path.lower()
and "wizard-2" in model_args.model_name_or_path.lower()
):
model = LlavaLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
elif "qwen" in model_args.model_name_or_path.lower():
if "moe" in model_args.model_name_or_path.lower() or "A14B" in model_args.model_name_or_path:
model = LlavaQwenMoeForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
deepspeed.utils.set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
else:
model = LlavaQwenForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
elif "gemma" in model_args.model_name_or_path.lower():
model = LlavaGemmaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
else:
raise ValueError(f"Unknown model class {model_args}")
else:
model = transformers.LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
return model
def train(attn_implementation=None):
global local_rank
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if training_args.verbose_logging:
rank0_print(f"Inspecting experiment hyperparameters:\n")
rank0_print(f"model_args = {vars(model_args)}\n\n")
rank0_print(f"data_args = {vars(data_args)}\n\n")
rank0_print(f"training_args = {vars(training_args)}\n\n")
# rank0_print(f"evaluation_args = {vars(evaluation_args)}\n\n")
local_rank = training_args.local_rank
compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
bnb_model_from_pretrained_args.update(
dict(
device_map={"": training_args.device},
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
quantization_config=BitsAndBytesConfig(
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
),
)
)
model = get_model(model_args, training_args, bnb_model_from_pretrained_args)
model.config.use_cache = False
if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
model.config.rope_scaling = {
"factor": model_args.rope_scaling_factor,
"type": model_args.rope_scaling_type,
}
if model_args.freeze_backbone:
model.model.requires_grad_(False)
if training_args.bits in [4, 8]:
from peft import prepare_model_for_kbit_training
model.config.torch_dtype = torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if training_args.lora_enable:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=training_args.lora_r,
lora_alpha=training_args.lora_alpha,
target_modules=find_all_linear_names(model),
lora_dropout=training_args.lora_dropout,
bias=training_args.lora_bias,
task_type="CAUSAL_LM",
)
if training_args.bits == 16:
if training_args.bf16:
model.to(torch.bfloat16)
if training_args.fp16:
model.to(torch.float16)
rank0_print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)
if "mistral" in model_args.model_name_or_path.lower() or "mixtral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="left")
elif "qwen" in model_args.model_name_or_path.lower():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right")
elif (
"wizardlm-2" in model_args.model_name_or_path.lower()
or "vicuna" in model_args.model_name_or_path.lower()
or "llama" in model_args.model_name_or_path.lower()
or "yi" in model_args.model_name_or_path.lower()
or "nous-hermes" in model_args.model_name_or_path.lower()
and "wizard-2" in model_args.model_name_or_path.lower()
):
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
rank0_print(f"Prompt version: {model_args.version}")
if model_args.version == "v0":
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token="[PAD]"),
tokenizer=tokenizer,
model=model,
)
elif model_args.version == "v0.5":
tokenizer.pad_token = tokenizer.unk_token
else:
if tokenizer.unk_token is not None:
tokenizer.pad_token = tokenizer.unk_token
if model_args.version in conversation_lib.conv_templates:
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
if model_args.vision_tower is not None:
model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
vision_tower = model.get_vision_tower()
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
data_args.image_processor = vision_tower.image_processor
data_args.is_multimodal = True
model.config.image_aspect_ratio = data_args.image_aspect_ratio
if data_args.image_grid_pinpoints is not None:
if isinstance(data_args.image_grid_pinpoints, str) and "x" in data_args.image_grid_pinpoints:
try:
patch_size = data_args.image_processor.size[0]
except Exception as e:
patch_size = data_args.image_processor.size["shortest_edge"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", data_args.image_grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
# Multiply all elements by patch_size
data_args.image_grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
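# Illustrative example (values assumed): a string covering "(1x1)" to "(2x2)" with
# patch_size=336 expands to [[336, 336], [336, 672], [672, 336], [672, 672]].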
elif isinstance(data_args.image_grid_pinpoints, str):
data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints)
model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
model.config.image_crop_resolution = data_args.image_crop_resolution
model.config.image_split_resolution = data_args.image_split_resolution
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length
model.config.mm_newline_position = model_args.mm_newline_position
model.config.add_faster_video = model_args.add_faster_video
model.config.faster_token_stride = model_args.faster_token_stride
model.config.add_time_instruction = data_args.add_time_instruction
model.config.force_sample = data_args.force_sample
model.config.mm_spatial_pool_stride = model_args.mm_spatial_pool_stride
### Decide which parts of the model to train
if model_args.mm_tunable_parts is None: # traditional way of deciding which part to train
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
model.config.tune_mm_vision_resampler = training_args.tune_mm_vision_resampler = model_args.tune_mm_vision_resampler
if model_args.tune_mm_mlp_adapter or model_args.tune_mm_vision_resampler:
model.requires_grad_(False)
if model_args.tune_mm_mlp_adapter:
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
if model_args.tune_mm_vision_resampler:
for p in model.get_model().vision_resampler.parameters():
p.requires_grad = True
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
if training_args.freeze_mm_mlp_adapter:
for p in model.get_model().mm_projector.parameters():
p.requires_grad = False
model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler
if training_args.freeze_mm_vision_resampler:
for p in model.get_model().vision_resampler.parameters():
p.requires_grad = False
model.config.unfreeze_mm_vision_tower = model_args.unfreeze_mm_vision_tower
if model_args.unfreeze_mm_vision_tower:
vision_tower.requires_grad_(True)
else:
vision_tower.requires_grad_(False)
else:
rank0_print(f"Using mm_tunable_parts: {model_args.mm_tunable_parts}")
model.config.mm_tunable_parts = training_args.mm_tunable_parts = model_args.mm_tunable_parts
# Set the entire model to not require gradients by default
model.requires_grad_(False)
vision_tower.requires_grad_(False)
model.get_model().mm_projector.requires_grad_(False)
model.get_model().vision_resampler.requires_grad_(False)
# Parse the mm_tunable_parts to decide which parts to unfreeze
tunable_parts = model_args.mm_tunable_parts.split(",")
if "mm_mlp_adapter" in tunable_parts:
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
if "mm_vision_resampler" in tunable_parts:
for p in model.get_model().vision_resampler.parameters():
p.requires_grad = True
if "mm_vision_tower" in tunable_parts:
for name, param in model.named_parameters():
if "vision_tower" in name:
param.requires_grad_(True)
if "mm_language_model" in tunable_parts:
for name, param in model.named_parameters():
if "vision_tower" not in name and "mm_projector" not in name and "vision_resampler" not in name:
param.requires_grad_(True)
total_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters())
trainable_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters() if p.requires_grad)
rank0_print(f"Total parameters: ~{total_params/1e6:.2f} MB)")
rank0_print(f"Trainable parameters: ~{trainable_params/1e6:.2f} MB)")
if training_args.bits in [4, 8]:
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_projector_lr = training_args.mm_projector_lr
model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
training_args.use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
if training_args.bits in [4, 8]:
from peft.tuners.lora import LoraLayer
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if training_args.bf16:
module = module.to(torch.bfloat16)
if "norm" in name:
module = module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
if training_args.bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
if training_args.lora_enable:
state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
if training_args.local_rank == 0 or training_args.local_rank == -1:
if hasattr(model, "config"):
model.config.save_pretrained(training_args.output_dir)
if hasattr(model, "generation_config"):
model.generation_config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_trainables.bin"))
else:
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
rank0_print(f"Model saved to {training_args.output_dir}")
if __name__ == "__main__":
train()
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import deepspeed
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import ast
import yaml
import time
import random
import math
import re
import torch
import transformers
import tokenizers
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVADPOTrainer
from data_processing.utils import load_jsonl, load_json
from llava import conversation as conversation_lib
from llava.model import *
from llava.model.language_model.llava_qwen import LlavaQwenConfig
from llava.model.language_model.llava_llama import LlavaConfig
from llava.model.language_model.llava_mistral import LlavaMistralConfig
from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token
from llava.utils import rank0_print
from transformers import AutoConfig
import pickle
from trl.trainer.utils import DPODataCollatorWithPadding
from PIL import Image, ImageFile
from decord import VideoReader, cpu
ImageFile.LOAD_TRUNCATED_IMAGES = True
from packaging import version
from typing import Any
local_rank = None
import numpy as np
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse("0.14")
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
model_class_name: Optional[str] = field(default=None, metadata={"help": "Used to init model class, format is XXXXForCausalLM. e.g. currently XXXX is chosen from LlavaLlama, LlavaMixtral, LlavaMistral, Llama"})
mm_tunable_parts: Optional[str] = field(
default=None, metadata={"help": 'Could be "mm_mlp_adapter", "mm_vision_resampler", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_mlp_adapter,mm_language_model"'}
)
# deciding which part of the multimodal model to tune, will overwrite other previous settings
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=False)
tune_mm_mlp_adapter: bool = field(default=False)
tune_mm_vision_resampler: bool = field(default=False)
vision_tower: Optional[str] = field(default=None)
vision_tower_pretrained: Optional[str] = field(default=None)
unfreeze_mm_vision_tower: bool = field(default=False)
unfreeze_language_model: bool = field(default=False)
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
mm_projector_type: Optional[str] = field(default="linear")
mm_use_im_start_end: bool = field(default=False)
mm_use_im_patch_token: bool = field(default=True)
mm_patch_merge_type: Optional[str] = field(default="flat")
mm_vision_select_feature: Optional[str] = field(default="patch")
mm_resampler_type: Optional[str] = field(default=None)
mm_mask_drop_mode: str = field(default="fixed")
mm_mask_drop_skip_percentage: float = field(default=0.0)
mm_mask_drop_ratio: float = field(default=0.25)
mm_mask_drop_ratio_upper: Optional[float] = field(default=None)
mm_mask_drop_ratio_lower: Optional[float] = field(default=None)
mm_spatial_pool_stride: Optional[int] = field(default=None)
mm_spatial_pool_mode: str = field(default="average")
mm_spatial_pool_out_channels: Optional[int] = field(default=None)
mm_perceiver_depth: Optional[int] = field(default=3)
mm_perceiver_latents: Optional[int] = field(default=32)
mm_perceiver_ff_mult: Optional[float] = field(default=4)
mm_perceiver_pretrained: Optional[str] = field(default=None)
mm_qformer_depth: Optional[int] = field(default=3)
mm_qformer_latents: Optional[int] = field(default=32)
mm_qformer_pretrained: Optional[str] = field(default=None)
rope_scaling_factor: Optional[float] = field(default=None)
rope_scaling_type: Optional[str] = field(default=None)
s2: Optional[bool] = field(default=False)
s2_scales: Optional[str] = field(default="336,672,1008")
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data, in llava's instruction.json format. Supporting multiple json files via /path/to/{a,b,c}.json"})
lazy_preprocess: bool = False
is_multimodal: bool = False
image_folder: Optional[str] = field(default=None)
video_folder: Optional[str] = field(default=None)
video_fps: Optional[int] = field(default=1)
image_aspect_ratio: str = "square"
image_grid_pinpoints: Optional[str] = field(default=None)
image_crop_resolution: int = 384
image_split_resolution: int = 384
input_prompt: Optional[str] = field(default=None)
refine_prompt: Optional[bool] = field(default=False)
frames_upbound: Optional[int] = field(default=0)
num_sample: Optional[int] = field(default=None)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
remove_unused_columns: bool = field(default=False)
freeze_mm_mlp_adapter: bool = field(default=False)
freeze_mm_vision_resampler: bool = field(default=False)
mpt_attn_impl: Optional[str] = field(default="triton")
model_max_length: int = field(
default=4096,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
double_quant: bool = field(default=True, metadata={"help": "Compress the quantization statistics through double quantization."})
quant_type: str = field(default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."})
bits: int = field(default=16, metadata={"help": "How many bits to use."})
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
mm_projector_lr: Optional[float] = None
mm_vision_tower_lr: Optional[float] = None
group_by_varlen: bool = field(default=False)
group_by_modality_length: bool = field(default=False)
group_by_modality_length_auto: bool = field(default=False)
auto_find_batch_size: bool = field(default=False)
gradient_checkpointing: bool = field(default=True)
verbose_logging: bool = field(default=False)
attn_implementation: str = field(default="flash_attention_2", metadata={"help": "Use transformers attention implementation."})
dpo_alpha: float = field(default=1.0)
beta: float = field(default=0.1)
gamma: float = field(default=1.0)
generate_during_eval: bool = field(default=False)
precompute_ref_log_probs: bool = field(default=False)
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
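# maybe_zero_3 gathers parameters partitioned by DeepSpeed ZeRO-3 (identified via
# ds_id) before cloning them to CPU; non-partitioned params are cloned directly.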
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias.items():
if k in lora_bias_names:
to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)
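# For a typical LLaMA-style backbone this is expected to return projection names such
# as ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
# excluding multimodal modules and lm_head (assumption based on common HF module
# naming, not verified here).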
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter:
check_only_save_mm_adapter_tunnable = True
# only has mm_mlp_adapter and mm_vision_resampler in the tuneable parts
elif hasattr(trainer.args, "mm_tunable_parts") and (len(trainer.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in trainer.args.mm_tunable_parts or "mm_vision_resampler" in trainer.args.mm_tunable_parts)):
check_only_save_mm_adapter_tunnable = True
else:
check_only_save_mm_adapter_tunnable = False
trainer.accelerator.wait_for_everyone()
torch.cuda.synchronize()
rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}")
if check_only_save_mm_adapter_tunnable:
# Only save Adapter
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"))
else:
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
return
if trainer.deepspeed:
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(target, tokenized_lens, speakers):
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker == "human":
target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
def _add_speaker_and_signal(header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = "### "
END_SIGNAL = "\n"
conversation = header
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = conversation_lib.default_conversation.roles[1]
else:
from_str = "unknown"
sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
if get_conversation:
conversation += sentence["value"]
conversation += BEGIN_SIGNAL
return conversation
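# With the default roles this produces text shaped roughly like:
#   {system header}\n\n### Human: ...\n### Assistant: ...\n###
# (exact role names depend on conversation_lib.default_conversation).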
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
if DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>")
replace_token = DEFAULT_IMAGE_TOKEN
if data_args.mm_use_im_start_end:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources
def preprocess_multimodal_movie(sources: Sequence[str], data_args: DataArguments, video_inputs: str) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
prompt = None  # fallback in case no image token is found in any sentence
for source in sources:
for sentence in source:
if DEFAULT_IMAGE_TOKEN in sentence["value"]:
prompt = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
replace_token = video_inputs
if data_args.mm_use_im_start_end:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources, prompt
def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
# Mask targets
sep = "[/INST] "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
rank0_print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def make_conv(prompt, answer):
return [
{
"from": "human",
"value": prompt,
},
{
"from": "gpt",
"value": answer,
},
]
def preprocess_gemma(sources: List[List[Dict[str, str]]], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
conv: conversation_lib.Conversation = conversation_lib.default_conversation.copy()
roles: Dict[str, str] = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations: List[str] = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source: List[Dict[str, str]] = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role: str = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids: torch.Tensor = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
input_ids: torch.Tensor = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets: torch.Tensor = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA
# Mask target
sep: str = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
total_len: int = int(target.ne(tokenizer.pad_token_id).sum())
rounds: List[str] = conversation.split(conv.sep)
re_rounds = []
for conv_idx in range(0, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2]))
cur_len = 1 # Ignore <bos>
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep # Re-append sep because split on this
# Now "".join(parts)==rou
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 # Ignore <bos>
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # Ignore <bos>
else:
round_len = len(tokenizer(rou).input_ids) - 1 # Ignore <bos>
instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # Ignore <bos>
round_len += 2 # sep: <end_of_turn>\n takes 2 tokens
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
rank0_print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
im_start, im_end = tokenizer.additional_special_tokens_ids
nl_tokens = tokenizer("\n").input_ids
_system = tokenizer("system").input_ids + nl_tokens
_user = tokenizer("user").input_ids + nl_tokens
_assistant = tokenizer("assistant").input_ids + nl_tokens
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
input_id += system
target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
assert len(input_id) == len(target)
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
if has_image and "<image>" in sentence["value"]:
assert sentence["value"].startswith("<image>"), print(sentence["value"])
_input_id = tokenizer(role).input_ids + nl_tokens + [IMAGE_TOKEN_INDEX] + nl_tokens + tokenizer(sentence["value"][len("<image>") :]).input_ids + [im_end] + nl_tokens
else:
_input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
input_id += _input_id
if role == "<|im_start|>user":
_target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
elif role == "<|im_start|>assistant":
_target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
else:
raise NotImplementedError
target += _target
assert len(input_id) == len(target)
# input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))
# target += [IGNORE_INDEX] * (max_len - len(target))
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids, # tensor(bs x seq_len)
labels=targets, # tensor(bs x seq_len)
# attention_mask=input_ids.ne(tokenizer.pad_token_id), # tensor(bs x seq_len)
)
def preprocess_llama3(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
max_len=2048,
system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
) -> Dict:
roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
nl_tokens = tokenizer("\n").input_ids
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
system = tokenizer("<|begin_of_text|>").input_ids + tokenizer("<|start_header_id|>system<|end_header_id|>").input_ids + nl_tokens * 2 + tokenizer(system_message).input_ids + [eot_id]
input_id += system
target += [IGNORE_INDEX] * len(system)
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
if has_image and "<image>" in sentence["value"]:
assert sentence["value"].startswith("<image>"), print(sentence["value"])
_input_id = tokenizer(role).input_ids + nl_tokens * 2 + [IMAGE_TOKEN_INDEX] + tokenizer(sentence["value"][len("<image>") :]).input_ids + [eot_id]
else:
_input_id = tokenizer(role).input_ids + nl_tokens * 2 + tokenizer(sentence["value"]).input_ids + [eot_id]
input_id += _input_id
if role == "<|start_header_id|>user<|end_header_id|>":
_target = [IGNORE_INDEX] * len(_input_id)
elif role == "<|start_header_id|>assistant<|end_header_id|>":
_target = [IGNORE_INDEX] * (len(tokenizer(role).input_ids) + 2) + _input_id[len(tokenizer(role).input_ids) + 2 : -1] + [eot_id]
else:
raise NotImplementedError
target += _target
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids, # tensor(bs x seq_len)
labels=targets, # tensor(bs x seq_len)
)
def preprocess_v1(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
round_len -= 1
instruction_len -= 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_mpt(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
# Mask targets
sep = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
for conv_idx in range(3, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
if i != 0 and getattr(tokenizer, "legacy", False) and IS_TOKENIZER_GREATER_THAN_0_14:
round_len += 1
instruction_len += 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f"(#turns={len(re_rounds)} ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
source[0]["value"] = DEFAULT_IMAGE_TOKEN
conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
conversations.append(conversation)
# tokenize conversations
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess(sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
"""
Given a list of sources, each is a conversation list. This transform:
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
return preprocess_llama_2(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version.startswith("v1"):
return preprocess_v1(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "mpt":
return preprocess_mpt(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "qwen":
return preprocess_qwen(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "gemma":
return preprocess_gemma(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "llama_v3":
return preprocess_llama3(sources, tokenizer, has_image=has_image)
# add end signal and concatenate together
conversations = []
for source in sources:
header = f"{conversation_lib.default_conversation.system}\n\n"
conversation = _add_speaker_and_signal(header, source)
conversations.append(conversation)
# tokenize conversations
def get_tokenize_len(prompts):
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
if has_image:
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
else:
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
if has_image:
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
else:
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
speakers = [sentence["from"] for sentence in source]
_mask_targets(target, tokenized_lens, speakers)
return dict(input_ids=input_ids, labels=targets)
def load_data(data_path):
if "jsonl" in data_path:
data_list = load_jsonl(data_path)
else:
data_list = load_json(data_path)
return data_list
class DPODataset(Dataset):
"""Dataset for DPODataset fine-tuning."""
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
super(DPODataset, self).__init__()
# Handle multiple JSON files specified in the data_path
self.list_data_dict = []
if "{" in data_path and "}" in data_path:
base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
file_names = file_pattern.split(",")
rank0_print(f"Loading {file_names} from {base_path}")
data_args.dataset_paths = []
for file_name in file_names:
data_args.dataset_paths.append(f"{base_path}{file_name}.json")
full_path = f"{base_path}{file_name}.json"
rank0_print(f"Loading {full_path}")
cur_data_dict = load_data(full_path)
rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
self.list_data_dict.extend(cur_data_dict)
elif data_path.endswith(".yaml"):
with open(data_path, "r") as file:
yaml_data = yaml.safe_load(file)
datasets = yaml_data.get("datasets")
# file should be in the format of:
# datasets:
# - json_path: xxxx1.json
# sampling_strategy: first:1000
# - json_path: xxxx2.json
# sampling_strategy: end:3000
# - json_path: xxxx3.json
# sampling_strategy: random:999
data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
for dataset in datasets:
json_path = dataset.get("json_path")
sampling_strategy = dataset.get("sampling_strategy", "all")
sampling_number = None
rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
cur_data_dict = load_data(json_path)
if ":" in sampling_strategy:
sampling_strategy, sampling_number = sampling_strategy.split(":")
if "%" in sampling_number:
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
else:
sampling_number = int(sampling_number)
# Apply the sampling strategy
if sampling_strategy == "first" and sampling_number is not None:
cur_data_dict = cur_data_dict[:sampling_number]
elif sampling_strategy == "end" and sampling_number is not None:
cur_data_dict = cur_data_dict[-sampling_number:]
elif sampling_strategy == "random" and sampling_number is not None:
random.shuffle(cur_data_dict)
cur_data_dict = cur_data_dict[:sampling_number]
rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
self.list_data_dict.extend(cur_data_dict)
else:
data_args.dataset_paths = [data_path]
rank0_print(f"Loading {data_path}")
cur_data_dict = load_data(data_path)
rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
self.list_data_dict.extend(cur_data_dict)
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
# Calculate the length of the prompt, answer, chosen, and rejected text
cur_len = len(sample["prompt"].split()) + len(sample["answer"].split()) + len(sample["chosen"].split()) + len(sample["rejected"].split())
# Add additional tokens if an image is present
img_tokens = 128 if "image" in sample else 0
length_list.append(cur_len + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
# Calculate the length of the prompt, answer, chosen, and rejected text
cur_len = len(sample["prompt"].split()) + len(sample["answer"].split()) + len(sample["chosen"].split()) + len(sample["rejected"].split())
# If the sample includes a video, the length is positive; otherwise, it is negative
cur_len = cur_len if ("video" in sample or "image" in sample) else -cur_len
length_list.append(cur_len)
return length_list
def process_image(self, image_file):
image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
# print(f"\n\nInspecting the image path, folder = {image_folder}, image={image_file}\n\n")
try:
image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
except Exception as exn:
print(f"Failed to open image {image_file}. Exception:", exn)
raise exn
image_size = image.size
if self.data_args.image_aspect_ratio == "highres":
image = process_highres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
elif self.data_args.image_aspect_ratio == "anyres" or "anyres" in self.data_args.image_aspect_ratio:
image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
elif self.data_args.image_aspect_ratio == "crop_split":
image = process_highres_image_crop_split(image, self.data_args)
elif self.data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
else:
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
return image, image_size, "image"
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# TODO: define number of retries somewhere else
num_base_retries = 3
num_final_retries = 300
# try the current sample first
for attempt_idx in range(num_base_retries):
try:
sample = self._get_item(i)
return sample
except Exception as e:
# sleep 1s in case it is a cloud disk issue
print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
time.sleep(1)
# try other samples, in case it is file corruption issue
for attempt_idx in range(num_base_retries):
try:
next_index = min(i + 1, len(self.list_data_dict) - 1)
# sample_idx = random.choice(range(len(self)))
sample = self._get_item(next_index)
return sample
except Exception as e:
# no need to sleep
print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)
pass
# still fail, most likely to be path issue or cloud disk issue, retry the same sample for longer
# for attempt_idx in range(num_final_retries):
# try:
# sample = self._get_item(i)
# return sample
# except Exception as e:
# # sleep 1s in case it is a cloud disk issue
# print(f"[Final try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
# time.sleep(1)
# Finally raise exception on failing.
assert False, "Failed to fetch sample."
def _get_item(self, i) -> Dict[str, torch.Tensor]:
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
suffix = None
if "image" in sources[0]:
image_file = self.list_data_dict[i]["image"]
if type(image_file) is list:
image = [self.process_image(f) for f in image_file]
else:
image = [self.process_image(image_file)]
# sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
elif "video" in sources[0]: # FIXME: This logic should be largely improved by Yuanhan. It's too messy now.
video_file = self.list_data_dict[i]["video"]
video_folder = self.data_args.video_folder
video_file = os.path.join(video_folder, video_file)
suffix = video_file.split(".")[-1]
if not os.path.exists(video_file):
print("File {} not exist!".format(video_file))
if suffix == "pkl":
video_info = pickle.load(open(video_file, "rb"))
image = torch.from_numpy(video_info["feats"][:, 1:])
input_prompt = video_info["inputs"].replace("...", "")
# replace the default image token with multiple tokens
input_prompt = input_prompt.replace(DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN * self.data_args.video_token)
sources, query_prompt = preprocess_multimodal_movie(copy.deepcopy([e["conversations"] for e in sources]), self.data_args, input_prompt)
else: # using videoreader
if "shareVideoGPTV" not in video_file and "liangke" not in video_file:
vr = VideoReader(video_file, ctx=cpu(0))
total_frame_num = len(vr)
avg_fps = round(vr.get_avg_fps() / self.data_args.video_fps)
frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
if self.data_args.frames_upbound > 0:
if len(frame_idx) > self.data_args.frames_upbound:
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.data_args.frames_upbound, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
video = vr.get_batch(frame_idx).asnumpy()
video = np.array(video)
else:
if "liangke" in video_file:
video_file = self.list_data_dict[i]["video"]
frame_files = [os.path.join(video_file, f) for f in os.listdir(video_file) if os.path.isfile(os.path.join(video_file, f))]
frame_files.sort() # Ensure the frames are sorted if they are named sequentially
# TODO: Hard CODE: Determine the indices for uniformly sampling 10 frames
num_frames_to_sample = 10
total_frames = len(frame_files)
sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
# Read and store the sampled frames
video = []
for idx in sampled_indices:
frame_path = frame_files[idx]
try:
with Image.open(frame_path) as img:
frame = img.convert("RGB")
video.append(frame)
except IOError:
print(f"Failed to read frame at path: {frame_path}")
processor = self.data_args.image_processor
image = processor.preprocess(video, return_tensors="pt")["pixel_values"]
image = [(image, video[0].size, "video")]
# sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
has_image = ("image" in self.list_data_dict[i]) or ("video" in self.list_data_dict[i])
# data_dict = preprocess(sources, self.tokenizer, has_image=has_image)
data_dict = copy.deepcopy(self.list_data_dict[i]) # inplace modification following
if "prompt" in data_dict:
prompt = data_dict["prompt"]
prompt = prompt.replace("<image>", "").strip()
prompt = "<image>\n" + prompt
data_dict["prompt"] = prompt
else:
prompt = None
if suffix == "pkl":
prompt = [query_prompt]
# image exist in the data
if "image" in self.list_data_dict[i]:
data_dict["image"] = image
elif "video" in self.list_data_dict[i]:
data_dict["image"] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict["image"] = [
(torch.zeros(1, 3, crop_size["height"], crop_size["width"]), (crop_size["width"], crop_size["height"]), "text"),
]
# prompt exist in the data
data_dict["has_image"] = has_image
return data_dict
@dataclass
class DPODataCollator(DPODataCollatorWithPadding):
"""Collate examples for DPO fine-tuning."""
# tokenizer: transformers.PreTrainedTokenizer
def collate(self, batch):
# first, pad everything to the same length
# input_ids, labels = tuple([instance[key] for instance in instances]
# for key in ("input_ids", "labels"))
# input_ids = torch.nn.utils.rnn.pad_sequence(
# input_ids,
# batch_first=True,
# padding_value=self.tokenizer.pad_token_id)
# labels = torch.nn.utils.rnn.pad_sequence(labels,
# batch_first=True,
# padding_value=IGNORE_INDEX)
# input_ids = input_ids[:, :self.tokenizer.model_max_length]
# labels = labels[:, :self.tokenizer.model_max_length]
# batch = dict(
# input_ids=input_ids,
# labels=labels,
# attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
# )
padded_batch = {}
for k in batch[0].keys():
if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
# if "prompt" in k:
# to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
# else:
to_pad = [torch.LongTensor(ex[k]) for ex in batch]
if k.endswith("_input_ids"):
padding_value = self.tokenizer.pad_token_id
elif k.endswith("_labels"):
padding_value = self.label_pad_token_id
else:
continue
# elif k.endswith("_attention_mask"):
# padding_value = self.padding_value
# else:
# raise ValueError(f"Unexpected key in batch '{k}'")
padded_batch[k] = torch.nn.utils.rnn.pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
# for the prompt, flip back so padding is on left side
# if "prompt" in k:
# padded_batch[k] = padded_batch[k].flip(dims=[1])
else:
padded_batch[k] = [ex[k] for ex in batch]
for k in ["chosen_input_ids", "rejected_input_ids"]:
attn_k = k.replace("input_ids", "attention_mask")
padded_batch[attn_k] = padded_batch[k].ne(self.tokenizer.pad_token_id)
return padded_batch
def tokenize_batch_element(self, prompt: str, chosen: str, rejected: str, has_image: bool = True) -> Dict:
"""Tokenize a single batch element.
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + chosen or prompt + rejected responses is/are too long. First
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
We also create the labels for the chosen/rejected responses, which are of length equal to
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
"""
# import pdb; pdb.set_trace()
batch = {}
chosen_sources = make_conv(prompt, chosen)
rejected_sources = make_conv(prompt, rejected)
chosen_data_dict = preprocess([chosen_sources], self.tokenizer, has_image=has_image)
# chosen_data_dict['attention_mask'] = chosen_data_dict["input_ids"].ne(self.tokenizer.pad_token_id)
rejected_data_dict = preprocess([rejected_sources], self.tokenizer, has_image=has_image)
# rejected_data_dict['attention_mask'] = rejected_data_dict["input_ids"].ne(self.tokenizer.pad_token_id)
chosen_data_dict = {k: v[0] for k, v in chosen_data_dict.items()}
rejected_data_dict = {k: v[0] for k, v in rejected_data_dict.items()}
for k, toks in {
"chosen": chosen_data_dict,
"rejected": rejected_data_dict,
}.items():
for type_key, tokens in toks.items():
if type_key == "token_type_ids":
continue
batch[f"{k}_{type_key}"] = tokens
return batch
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
tokenized_batch = []
Xs, keys = [], []
for feature in features:
prompt = feature["prompt"]
chosen = feature["chosen"]
rejected = feature["rejected"]
has_image = feature["has_image"]
# Xs.append(feature[has_X])
# keys.append(has_X)
batch_element = self.tokenize_batch_element(prompt, chosen, rejected, has_image=has_image)
tokenized_batch.append(batch_element)
# return collated batch
padded_batch = self.collate(tokenized_batch)
# import pdb;pdb.set_trace()
if "image" in features[0]:
# instances[1]['image'][0][0].shape
# torch.Size([5, 3, 224, 224])
images = [instance["image"] for instance in features]
padded_batch["image_sizes"] = [im[1] for im_list in images for im in im_list]
padded_batch["modalities"] = [im[2] for im_list in images for im in im_list]
images = [im[0] for im_list in images for im in im_list]
# import pdb;pdb.set_trace()
padded_batch["images"] = images
# padded_batch["images"] =[padded_batch["modalities"], images]
return padded_batch
def make_dpo_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = DPODataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
return train_dataset
def get_model(model_args, training_args, bnb_model_from_pretrained_args):
assert training_args.attn_implementation
if training_args.attn_implementation == "sdpa" and torch.__version__ < "2.1.2":
raise ValueError("The 'sdpa' attention implementation requires torch version 2.1.2 or higher.")
######################### Overwrite config #########################
customized_kwargs = dict()
customized_kwargs.update(bnb_model_from_pretrained_args)
overwrite_config = {}
cfg_pretrained = None
if "qwen" in model_args.model_name_or_path.lower():
cfg_pretrained = LlavaQwenConfig.from_pretrained(model_args.model_name_or_path)
elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
cfg_pretrained = LlavaMistralConfig.from_pretrained(model_args.model_name_or_path)
elif (
"wizardlm-2" in model_args.model_name_or_path.lower()
or "vicuna" in model_args.model_name_or_path.lower()
or "llama" in model_args.model_name_or_path.lower()
or "yi" in model_args.model_name_or_path.lower()
or "nous-hermes" in model_args.model_name_or_path.lower()
and "wizard-2" in model_args.model_name_or_path.lower()
):
cfg_pretrained = LlavaConfig.from_pretrained(model_args.model_name_or_path)
else:
cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None and cfg_pretrained is not None:
overwrite_config["rope_scaling"] = {
"factor": model_args.rope_scaling_factor,
"type": model_args.rope_scaling_type,
}
if training_args.model_max_length is None:
training_args.model_max_length = cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor
overwrite_config["max_sequence_length"] = training_args.model_max_length
assert training_args.model_max_length == int(cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor), print(
f"model_max_length: {training_args.model_max_length}, max_position_embeddings: {cfg_pretrained.max_position_embeddings}, rope_scaling_factor: {model_args.rope_scaling_factor}"
)
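# e.g. a base model with max_position_embeddings=4096 and rope_scaling_factor=2.0
# ends up with model_max_length=8192 when model_max_length is not set explicitly.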
# overwrite_config["max_sequence_length"] = model_args.max_sequence_length
# overwrite_config["tokenizer_model_max_length"] = model_args.tokenizer_model_max_length
if model_args.mm_spatial_pool_stride is not None and model_args.mm_spatial_pool_out_channels is not None and model_args.mm_spatial_pool_mode is not None and model_args.mm_resampler_type is not None and cfg_pretrained is not None:
overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type
overwrite_config["mm_spatial_pool_stride"] = model_args.mm_spatial_pool_stride
overwrite_config["mm_spatial_pool_out_channels"] = model_args.mm_spatial_pool_out_channels
overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode
if overwrite_config:
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(cfg_pretrained, k, v)
customized_kwargs["config"] = cfg_pretrained
######################### Finish Overwrite ###########################
ref_model = None
if model_args.model_class_name is not None:
actual_model_class_name = f"{model_args.model_class_name}ForCausalLM"
model_class = getattr(transformers, actual_model_class_name)
rank0_print(f"Using model class {model_class} from {model_args.model_class_name}")
model = model_class.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
elif model_args.vision_tower is not None:
if "mixtral" in model_args.model_name_or_path.lower():
model = LlavaMixtralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
model = LlavaMistralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
elif (
"wizardlm-2" in model_args.model_name_or_path.lower()
or "vicuna" in model_args.model_name_or_path.lower()
or "llama" in model_args.model_name_or_path.lower()
or "yi" in model_args.model_name_or_path.lower()
or "nous-hermes" in model_args.model_name_or_path.lower()
and "wizard-2" in model_args.model_name_or_path.lower()
):
model = LlavaLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
if "zero3" in training_args.deepspeed:
rank0_print("#### Initialize reference model #####")
ref_model = LlavaLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
elif "qwen" in model_args.model_name_or_path.lower() or "quyen" in model_args.model_name_or_path.lower():
if "moe" in model_args.model_name_or_path.lower():
model = LlavaQwenMoeForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
deepspeed.utils.set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
else:
model = LlavaQwenForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
if "zero3" in training_args.deepspeed:
rank0_print("#### Initialize reference model #####")
ref_model = LlavaQwenForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
elif "gemma" in model_args.model_name_or_path.lower():
model = LlavaGemmaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs,
)
else:
raise ValueError(f"Unknown model class {model_args}")
else:
model = transformers.LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation=training_args.attn_implementation, torch_dtype=(torch.bfloat16 if training_args.bf16 else None), **customized_kwargs
)
return model, ref_model
def train(attn_implementation=None):
global local_rank
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if training_args.verbose_logging:
rank0_print(f"Inspecting experiment hyperparameters:\n")
rank0_print(f"model_args = {vars(model_args)}\n\n")
rank0_print(f"data_args = {vars(data_args)}\n\n")
rank0_print(f"training_args = {vars(training_args)}\n\n")
# rank0_print(f"evaluation_args = {vars(evaluation_args)}\n\n")
local_rank = training_args.local_rank
compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
bnb_model_from_pretrained_args.update(
dict(
device_map={"": training_args.device},
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
quantization_config=BitsAndBytesConfig(
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
),
)
)
model, ref_model = get_model(model_args, training_args, bnb_model_from_pretrained_args)
model.config.use_cache = False
if model_args.freeze_backbone:
model.model.requires_grad_(False)
if training_args.bits in [4, 8]:
from peft import prepare_model_for_kbit_training
model.config.torch_dtype = torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if ref_model is not None:
ref_model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if ref_model is not None:
ref_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if training_args.lora_enable:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=training_args.lora_r,
lora_alpha=training_args.lora_alpha,
target_modules=find_all_linear_names(model),
lora_dropout=training_args.lora_dropout,
bias=training_args.lora_bias,
task_type="CAUSAL_LM",
)
if training_args.bits == 16:
if training_args.bf16:
model.to(torch.bfloat16)
if training_args.fp16:
model.to(torch.float16)
rank0_print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)
if "mpt" in model_args.model_name_or_path:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right")
elif "mistral" in model_args.model_name_or_path.lower() or "mixtral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="left")
elif "qwen" in model_args.model_name_or_path.lower():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right")
else: # for all other models
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
rank0_print(f"Prompt version: {model_args.version}")
if model_args.version == "v0":
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token="[PAD]"),
tokenizer=tokenizer,
model=model,
)
elif model_args.version == "v0.5":
tokenizer.pad_token = tokenizer.unk_token
else:
if tokenizer.unk_token is not None:
tokenizer.pad_token = tokenizer.unk_token
if model_args.version in conversation_lib.conv_templates:
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
if model_args.vision_tower is not None:
model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
vision_tower = model.get_vision_tower()
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
data_args.image_processor = vision_tower.image_processor
data_args.is_multimodal = True
model.config.image_aspect_ratio = data_args.image_aspect_ratio
if data_args.image_grid_pinpoints is not None:
# for input like "(1x1)...(3x3)", convert to [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (3, 2), (1, 3), (2, 3), (3, 3)]
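# e.g. with a 384px vision encoder, "(1x1)...(2x2)" expands to [(1, 1), (1, 2), (2, 1), (2, 2)]
# and is then scaled to pixel pinpoints [[384, 384], [384, 768], [768, 384], [768, 768]]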
if "x" in data_args.image_grid_pinpoints and "..." in data_args.image_grid_pinpoints:
vis_encoder_size = data_args.image_processor.size[0]
matches = re.findall(r"\((\d+)x(\d+)\)", data_args.image_grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
grid_pinpoints = [[dim * vis_encoder_size for dim in pair] for pair in grid_pinpoints]
data_args.image_grid_pinpoints = grid_pinpoints
elif "x" in data_args.image_grid_pinpoints:
vis_encoder_size = data_args.image_processor.size[0]
assert vis_encoder_size in [224, 336, 384, 448, 512], "vis_encoder_size should be in [224, 336, 384, 448, 512]"
grid_pinpoints = data_args.image_grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(")
data_args.image_grid_pinpoints = [[int(x) * vis_encoder_size for x in item.split(",")] for item in grid_pinpoints]
else:
data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints) # for backward compatibility
model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
model.config.image_crop_resolution = data_args.image_crop_resolution
model.config.image_split_resolution = data_args.image_split_resolution
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length
### Deciding train which part of the model
if model_args.mm_tunable_parts is None: # traditional way of deciding which part to train
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
model.config.tune_mm_vision_resampler = training_args.tune_mm_vision_resampler = model_args.tune_mm_vision_resampler
if model_args.tune_mm_mlp_adapter or model_args.tune_mm_vision_resampler:
model.requires_grad_(False)
if model_args.tune_mm_mlp_adapter:
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
if model_args.tune_mm_vision_resampler:
for p in model.get_model().vision_resampler.parameters():
p.requires_grad = True
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
if training_args.freeze_mm_mlp_adapter:
for p in model.get_model().mm_projector.parameters():
p.requires_grad = False
model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler
if training_args.freeze_mm_vision_resampler:
for p in model.get_model().vision_resampler.parameters():
p.requires_grad = False
model.config.unfreeze_mm_vision_tower = model_args.unfreeze_mm_vision_tower
if model_args.unfreeze_mm_vision_tower:
vision_tower.requires_grad_(True)
else:
vision_tower.requires_grad_(False)
else:
rank0_print(f"Using mm_tunable_parts: {model_args.mm_tunable_parts}")
model.config.mm_tunable_parts = training_args.mm_tunable_parts = model_args.mm_tunable_parts
# Set the entire model to not require gradients by default
model.requires_grad_(False)
vision_tower.requires_grad_(False)
model.get_model().mm_projector.requires_grad_(False)
model.get_model().vision_resampler.requires_grad_(False)
# Parse the mm_tunable_parts to decide which parts to unfreeze
tunable_parts = model_args.mm_tunable_parts.split(",")
if "mm_mlp_adapter" in tunable_parts:
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
if "mm_vision_resampler" in tunable_parts:
for p in model.get_model().vision_resampler.parameters():
p.requires_grad = True
if "mm_vision_tower" in tunable_parts:
for name, param in model.named_parameters():
if "vision_tower" in name:
param.requires_grad_(True)
if "mm_language_model" in tunable_parts:
for name, param in model.named_parameters():
if "vision_tower" not in name and "mm_projector" not in name and "vision_resampler" not in name:
param.requires_grad_(True)
total_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters())
trainable_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters() if p.requires_grad)
rank0_print(f"Total parameters: ~{total_params/1e6:.2f} MB)")
rank0_print(f"Trainable parameters: ~{trainable_params/1e6:.2f} MB)")
if training_args.bits in [4, 8]:
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_projector_lr = training_args.mm_projector_lr
model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
training_args.use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
if ref_model is not None:
ref_model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
ref_vision_tower = ref_model.get_vision_tower()
ref_vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
ref_model.config.image_aspect_ratio = data_args.image_aspect_ratio
ref_model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
ref_model.config.image_crop_resolution = data_args.image_crop_resolution
ref_model.config.image_split_resolution = data_args.image_split_resolution
ref_model.config.tokenizer_padding_side = tokenizer.padding_side
ref_model.config.tokenizer_model_max_length = tokenizer.model_max_length
ref_model.config.mm_use_im_start_end = data_args.mm_use_im_start_end
ref_model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
ref_model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
parameter_names = [n for n, _ in ref_model.named_parameters()]
for param_name in parameter_names:
param = ref_model.get_parameter(param_name)
param.requires_grad = False
ref_model.eval()
if training_args.bits in [4, 8]:
from peft.tuners.lora import LoraLayer
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if training_args.bf16:
module = module.to(torch.bfloat16)
if "norm" in name:
module = module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
if training_args.bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
train_dataset = make_dpo_data_module(tokenizer=tokenizer, data_args=data_args)
data_collator = DPODataCollator(
tokenizer,
label_pad_token_id=IGNORE_INDEX,
pad_token_id=tokenizer.pad_token_id,
)
trainer = LLaVADPOTrainer(
model,
ref_model,
args=training_args,
dpo_alpha=training_args.dpo_alpha,
beta=training_args.beta,
gamma=training_args.gamma,
train_dataset=train_dataset,
eval_dataset=None,
data_collator=data_collator,
tokenizer=tokenizer,
max_length=training_args.model_max_length,
generate_during_eval=False, # training_args.generate_during_eval,
precompute_ref_log_probs=training_args.precompute_ref_log_probs,
)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
if training_args.lora_enable:
state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
if training_args.local_rank == 0 or training_args.local_rank == -1:
if hasattr(model, "config"):
model.config.save_pretrained(training_args.output_dir)
if hasattr(model, "generation_config"):
model.generation_config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_trainables.bin"))
else:
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
rank0_print(f"Model saved to {training_args.output_dir}")
if __name__ == "__main__":
train()
from llava.train.train import train
if __name__ == "__main__":
train()
import datetime
import logging
import logging.handlers
import os
import sys
import numpy as np
import requests
from llava.constants import LOGDIR
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content."
handler = None
import torch.distributed as dist
try:
import av
from decord import VideoReader, cpu
except ImportError:
print("Please install pyav to use video processing functions.")
def process_video_with_decord(video_file, data_args):
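# Sketch of behavior: sample frames at roughly data_args.video_fps (falling back to uniform
# sampling when the count exceeds data_args.frames_upbound or force_sample is set) and return
# (frames ndarray, video duration in seconds, a string of frame timestamps, number of sampled frames).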
vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
total_frame_num = len(vr)
video_time = total_frame_num / vr.get_avg_fps()
avg_fps = round(vr.get_avg_fps() / data_args.video_fps)
frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
frame_time = [i/avg_fps for i in frame_idx]
if data_args.frames_upbound > 0:
if len(frame_idx) > data_args.frames_upbound or data_args.force_sample:
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frame_time = [i/vr.get_avg_fps() for i in frame_idx]
video = vr.get_batch(frame_idx).asnumpy()
frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
num_frames_to_sample = num_frames = len(frame_idx)
# https://github.com/dmlc/decord/issues/208
vr.seek(0)
return video, video_time, frame_time, num_frames_to_sample
def process_video_with_pyav(video_file, data_args):
container = av.open(video_file)
# !!! This is the only difference. Using auto threading
container.streams.video[0].thread_type = "AUTO"
video_frames = []
for packet in container.demux():
if packet.stream.type == 'video':
for frame in packet.decode():
video_frames.append(frame)
total_frame_num = len(video_frames)
video_time = video_frames[-1].time
avg_fps = round(total_frame_num / video_time / data_args.video_fps)
frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
if data_args.frames_upbound > 0:
if len(frame_idx) > data_args.frames_upbound:
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = [video_frames[i] for i in frame_idx]
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
def rank0_print(*args):
if dist.is_initialized():
if dist.get_rank() == 0:
print(f"Rank {dist.get_rank()}: ", *args)
else:
print(*args)
def rank_print(*args):
if dist.is_initialized():
print(f"Rank {dist.get_rank()}: ", *args)
else:
print(*args)
def build_logger(logger_name, logger_filename):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# Add a file handler for all loggers
if handler is None:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True)
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
return logger
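# Usage sketch (illustrative): the file handler rotates daily under LOGDIR and is
# shared by all existing loggers, so each component typically calls this once at startup.
#
#   logger = build_logger("gradio_web_server", "gradio_web_server.log")
#   logger.info("server started")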
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ""
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ""
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == "\n":
self.logger.log(self.log_level, line.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != "":
self.logger.log(self.log_level, self.linebuf.rstrip())
self.linebuf = ""
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def violates_moderation(text):
"""
Check whether the text violates OpenAI moderation API.
"""
url = "https://api.openai.com/v1/moderations"
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
text = text.replace("\n", "")
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
try:
ret = requests.post(url, headers=headers, data=data, timeout=5)
flagged = ret.json()["results"][0]["flagged"]
except requests.exceptions.RequestException as e:
print(f"######################### Moderation Error: {e} #########################")
flagged = False
except KeyError as e:
print(f"######################### Moderation Error: {e} #########################")
flagged = False
return flagged
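# Usage sketch (illustrative): requires the OPENAI_API_KEY environment variable;
# any network or parsing error is swallowed and treated as "not flagged".
#
#   if violates_moderation(user_text):
#       print(moderation_msg)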
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
# Unique model identifier
modelCode=1475
# Model name
modelName=LLaVA-NeXT_pytorch
# Model description
modelDescription=High-performance image and video question answering
# Application scenarios
appScenario=Inference, conversational QA, e-commerce, education, transportation, energy
# Framework type
frameType=Pytorch
# LLaVA-OneVision: Easy Visual Task Transfer
## Paper
`LLaVA-OneVision: Easy Visual Task Transfer`
* https://arxiv.org/pdf/2408.03326
## Model Architecture
The model consists of three components: an LLM (Qwen-2), a vision encoder (SigLIP), and a projector (a two-layer MLP).
![alt text](readme_imgs/arch.png)
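As a rough illustration of the projector component, the sketch below builds a two-layer MLP that maps vision-encoder features into the LLM embedding space; the hidden sizes are placeholders rather than the exact dimensions of any released checkpoint.
```python
import torch.nn as nn

# Illustrative projector: vision features -> LLM embedding space via a 2-layer MLP.
# vision_hidden / llm_hidden are placeholder sizes, not tied to a specific checkpoint.
vision_hidden, llm_hidden = 1152, 3584
projector = nn.Sequential(
    nn.Linear(vision_hidden, llm_hidden),
    nn.GELU(),
    nn.Linear(llm_hidden, llm_hidden),
)
```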
## Algorithm
The key ideas are the AnyRes visual representation and a task-transfer training recipe, which allow a single model to cover diverse visual scenarios (single-image, multi-image, and video).
![alt text](readme_imgs/alg.png)
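A minimal sketch of the AnyRes idea, under simplifying assumptions (fixed 2x2 grid, square tiles): the image is encoded once as a resized global view and additionally as a grid of tiles; the repository selects the grid adaptively per image rather than using a fixed one.
```python
from PIL import Image

def anyres_views(image: Image.Image, base: int = 384, grid=(2, 2)):
    """Illustrative only: one resized base view plus fixed grid tiles."""
    views = [image.resize((base, base))]  # global view
    resized = image.resize((grid[0] * base, grid[1] * base))
    for gy in range(grid[1]):
        for gx in range(grid[0]):
            # each tile is later encoded by the vision tower like a separate image
            views.append(resized.crop((gx * base, gy * base, (gx + 1) * base, (gy + 1) * base)))
    return views
```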
## Environment Setup
See [README.md](../README.md)
## Dataset
## Training
## Inference
### Native
Single-image input
```bash
python single_image.py
```
Interleaved image-text input
```bash
python image-text.py
```
Video input
```bash
python video.py
```
Note: modify the parameters inside the scripts before running them.
## Results
![alt text](readme_imgs/result.png)
### Accuracy
## Application Scenarios
See [README.md](../README.md)
## Pretrained Weights
|Model|URL|
|:---:|:---:|
|lmms-lab/llava-onevision-qwen2-7b-ov| [hf](https://hf-mirror.com/lmms-lab/llava-onevision-qwen2-7b-ov) \| [SCNet]() |
|lmms-lab/llava-onevision-qwen2-0.5b-ov| [hf](https://hf-mirror.com/lmms-lab/llava-onevision-qwen2-0.5b-ov) \| [SCNet]() |
|lmms-lab/llava-onevision-qwen2-0.5b-si| [hf](https://hf-mirror.com/lmms-lab/llava-onevision-qwen2-0.5b-si) \| [SCNet]() |
|lmms-lab/llava-onevision-qwen2-7b-si| [hf](https://hf-mirror.com/lmms-lab/llava-onevision-qwen2-7b-si) \| [SCNet]() |
|lmms-lab/llava-onevision-qwen2-7b-ov-chat| [hf](https://hf-mirror.com/lmms-lab/llava-onevision-qwen2-7b-ov-chat) \| [SCNet]() |
|lmms-lab/llava-onevision-qwen2-72b-ov-chat| [hf](https://hf-mirror.com/lmms-lab/llava-onevision-qwen2-72b-ov-chat) \| [SCNet]() |
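If downloading from the hf-mirror links above, one option (a sketch, not a step documented by this repo) is to point `huggingface_hub` at the mirror before fetching a checkpoint:
```python
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # set before importing huggingface_hub

from huggingface_hub import snapshot_download

# Downloads the checkpoint into the local HF cache and returns its path.
local_dir = snapshot_download(repo_id="lmms-lab/llava-onevision-qwen2-7b-ov")
print(local_dir)
```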
## Source Repository and Issue Feedback
* See [README.md](../README.md)
## References
* https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md