# Commit 07dbc76b authored by dongchy920
# MiniGemini_pytorch — concatenation of the MGM serve modules
# (gradio_web_server.py, model_worker.py, register_worker.py, sglang_worker.py)
import argparse
import datetime
import json
import os
import time
import gradio as gr
import requests
from mgm.conversation import (default_conversation, conv_templates,
SeparatorStyle)
from mgm.constants import LOGDIR
from mgm.utils import (build_logger, server_error_msg,
violates_moderation, moderation_msg)
import hashlib
# Module-level logger for the Gradio front-end.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# Sent with every HTTP request to the controller / model workers.
headers = {"User-Agent": "MGM Client"}

# Button-state templates returned from event handlers to toggle interactivity
# of the five action buttons (upvote/downvote/flag/regenerate/clear).
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

# Sort keys for the model dropdown: models listed here sort before the rest
# (their values are lexicographically smaller than typical model names).
priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}
def get_conv_log_filename():
    """Return the path of today's conversation log file under LOGDIR."""
    today = datetime.datetime.now()
    log_name = f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json"
    return os.path.join(LOGDIR, log_name)
def get_model_list():
    """Ask the controller for the currently served models, sorted by priority.

    Models present in the module-level ``priority`` dict sort first; the rest
    sort alphabetically by their own name.
    """
    refresh = requests.post(args.controller_url + "/refresh_all_workers")
    assert refresh.status_code == 200
    listing = requests.post(args.controller_url + "/list_models")
    models = sorted(listing.json()["models"],
                    key=lambda name: priority.get(name, name))
    logger.info(f"Models: {models}")
    return models
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    """Initialize a session, preselecting the model named in the URL params.

    Returns a fresh conversation state and a dropdown update; the dropdown is
    only given a value when the requested model is actually being served.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown(visible=True)
    requested = url_params["model"] if "model" in url_params else None
    if requested is not None and requested in models:
        dropdown_update = gr.Dropdown(value=requested, visible=True)

    return default_conversation.copy(), dropdown_update
def load_demo_refresh_model_list(request: gr.Request):
    """Initialize a session, re-fetching the model list from the controller."""
    logger.info(f"load_demo. ip: {request.client.host}")
    available = get_model_list()
    selected = available[0] if available else ""
    dropdown_update = gr.Dropdown(choices=available, value=selected)
    return default_conversation.copy(), dropdown_update
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one JSON vote record (upvote/downvote/flag) to today's log file."""
    record = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": request.client.host,
    }
    with open(get_conv_log_filename(), "a") as fout:
        fout.write(json.dumps(record) + "\n")
def upvote_last_response(state, model_selector, request: gr.Request):
    """Record an upvote; clears the textbox and disables the vote buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)


def downvote_last_response(state, model_selector, request: gr.Request):
    """Record a downvote; clears the textbox and disables the vote buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)


def flag_last_response(state, model_selector, request: gr.Request):
    """Record a flag; clears the textbox and disables the vote buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def regenerate(state, image_process_mode, request: gr.Request):
    """Drop the last bot reply so the previous user turn is generated again.

    If the previous human message carries an image tuple, its process mode is
    refreshed from the UI before resubmission.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    last_user_turn = state.messages[-2]
    # Exact type check (not isinstance) kept: only plain tuple/list payloads
    # carry (text, image, mode).
    if type(last_user_turn[1]) in (tuple, list):
        last_user_turn[1] = (*last_user_turn[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def clear_history(request: gr.Request):
    """Reset the conversation to a fresh default state and clear the inputs."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = default_conversation.copy()
    return (fresh, fresh.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append a user turn (text plus optional image) to the conversation.

    Returns the updated state, chatbot rendering, cleared textbox/imagebox,
    and button updates. Sets ``state.skip_next`` so the follow-up ``http_bot``
    call knows whether to actually generate.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        # Nothing to submit: mark the turn as skippable and leave buttons alone.
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            # Moderation hit: refuse the turn and show the canned message.
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if '<image>' not in text:
            # text = '<Image><image></Image>' + text
            # Ensure exactly one image placeholder token is present.
            text = text + '\n<image>'
        text = (text, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            # A previous image already exists in this conversation: restart it.
            # NOTE(review): this demo appears to support one image per
            # conversation — confirm against the worker's image handling.
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, gen_image, use_ocr, request: gr.Request):
    """Generator event handler: stream the model's reply into the chatbot.

    Looks up the worker for the selected model via the controller, posts a
    streaming generate request, and yields UI updates as chunks arrive.
    Yields tuples of (state, chatbot, *button_updates).
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: pick the prompt template that matches
        # the model family, then rebuild the state on that template.
        if "mgm" in model_name.lower():
            if '8x7b' in model_name.lower():
                template_name = "mistral_instruct"
            elif '34b' in model_name.lower():
                template_name = "chatml_direct"
            elif '2b' in model_name.lower():
                template_name = "gemma"
            else:
                template_name = "vicuna_v1"
        else:
            template_name = "vicuna_v1"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker: show the server error and re-enable regenerate/clear.
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Construct prompt
    prompt = state.get_prompt()

    # Save each uploaded image once, keyed by its MD5 hash, for logging.
    # NOTE: the loop variable `hash` shadows the builtin of the same name.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Generate Image: force generation on when the prompt asks for it, or for
    # one hard-coded demo image size.
    if 'generate' in prompt.lower():
        gen_image = 'Yes'
    elif 'show me one idea of what i could make with this?' in prompt.lower() and len(all_images) == 1:
        # NOTE(review): PIL Image.size is (width, height), so `h` here is
        # actually the width and `w` the height — the 922x672 comparison only
        # matches one specific demo asset; confirm intent before renaming.
        h, w = all_images[0].size
        if h == 922 and w == 672:
            gen_image = 'Yes'

    # Make requests. Images are first summarized (hashes only) for logging,
    # then replaced with the real base64 payload before sending.
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
        "gen_image": bool(gen_image == 'Yes'),
        "use_ocr": bool(use_ocr == 'Yes'),
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    # Show a streaming cursor while waiting for the first chunk.
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output: chunks are NUL-delimited JSON records from the worker.
        response = requests.post(worker_addr + "/worker_generate_stream",
            headers=headers, json=pload, stream=True, timeout=30)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    if 'image' not in data.keys():
                        # Worker echoes the prompt; strip it off the front.
                        output = data["text"][len(prompt):].strip()
                        state.messages[-1][-1] = output + "▌"
                    else:
                        # Final chunk with a generated image: (text, base64 image).
                        output = (data["text"][len(prompt):].strip(), data["image"])
                        state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Plain-text replies end with the "▌" cursor: strip it. Tuple replies
    # (text, image) never carry the cursor, so they are left untouched.
    if type(state.messages[-1][-1]) is not tuple:
        state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    # Append the completed exchange to today's conversation log.
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
title_markdown = ("""
# Mini-Gemini: Mining the Potential of Multi-modality Vision Language Models
""")
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Build and return the Gradio Blocks UI for the MGM chat demo.

    Args:
        embed_mode: hide title/terms markdown when True (for iframe embedding).
        cur_dir: directory holding the examples/ folder; defaults to this file's dir.
        concurrency_count: per-listener concurrency limit for http_bot calls.
    """
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="MGM", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            # Left column: model picker, image upload, examples, settings.
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/monday.jpg", "Explain why this meme is funny, and generate a picture when the weekend coming."],
                    [f"{cur_dir}/examples/woolen.png", "Show me one idea of what I could make with this?"],
                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
                    [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
                ], inputs=[imagebox, textbox])

                # NOTE(review): both accordions bind to the same local name
                # `parameter_row`; the first binding is lost. Harmless (the
                # name is never used) but worth renaming.
                with gr.Accordion("Function", open=True) as parameter_row:
                    gen_image = gr.Radio(choices=['Yes', 'No'], value='No', interactive=True, label="Generate Image")
                    use_ocr = gr.Radio(choices=['Yes', 'No'], value='Yes', interactive=True, label="Use OCR")

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            # Right column: chatbot, input box, action buttons.
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="MGM Chatbot",
                    height=940,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=7):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            # gr.Markdown(function_markdown)
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )

        # Regenerate / submit chain into http_bot, which streams the reply.
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        if args.model_list_mode == "once":
            # NOTE(review): `_js` is the pre-4.x Gradio keyword (newer
            # releases use `js=`) — confirm against the pinned gradio version.
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
if __name__ == "__main__":
    # CLI entry point: parse flags, fetch the served model list from the
    # controller, then launch the Gradio app.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=16)
    parser.add_argument("--model-list-mode", type=str, default="once",
        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
    demo.queue(
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )
# ---- end of gradio_web_server.py (concatenated file boundary) ----
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial
from mgm.constants import WORKER_HEART_BEAT_INTERVAL
from mgm.utils import (build_logger, server_error_msg,
pretty_print_semaphore)
from mgm.model.builder import load_pretrained_model
from mgm.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
from mgm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread
# Optional heavy dependencies: the worker can start without them, but image
# generation / OCR calls will fail at runtime if they are missing.
# Bare `except:` replaced with `except ImportError:` so that SystemExit /
# KeyboardInterrupt and unrelated errors are no longer silently swallowed.
try:
    from diffusers import StableDiffusionXLPipeline
except ImportError:
    print('please install diffusers==0.26.3')
try:
    from paddleocr import PaddleOCR
except ImportError:
    print('please install paddleocr following https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/README_en.md')

import io
import base64

GB = 1 << 30

# Short random id distinguishing this worker instance (used in the log name).
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Incremented per generate request; reported in heart-beat logs.
global_counter = 0
# Lazily-created asyncio.Semaphore bounding concurrent generations.
model_semaphore = None
def heart_beat_worker(controller):
    """Daemon loop: ping the controller every WORKER_HEART_BEAT_INTERVAL seconds.

    Runs forever in a background thread started by ModelWorker.__init__.
    """
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
class ModelWorker:
    """Serves one MGM model over HTTP: registers with the controller, sends
    heart beats, and streams generations (text, optional SDXL image output,
    optional OCR-augmented prompts)."""

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device, use_flash_attn=False):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            # Derive a display name from the path; include the parent dir for
            # intermediate checkpoints so they stay distinguishable.
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name
        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
        # self.is_multimodal = 'llava' in self.model_name.lower()
        self.is_multimodal = True

        # Models with an auxiliary (high-res) vision branch: bump the image
        # processor's crop/size to image_size_aux, keeping the original
        # resolution around as image_size_raw.
        if hasattr(self.model.config, 'image_size_aux'):
            if not hasattr(self.image_processor, 'image_size_raw'):
                self.image_processor.image_size_raw = self.image_processor.crop_size.copy()
            self.image_processor.crop_size['height'] = self.model.config.image_size_aux
            self.image_processor.crop_size['width'] = self.model.config.image_size_aux
            self.image_processor.size['shortest_edge'] = self.model.config.image_size_aux

        # ocr model
        self.ocr_model = PaddleOCR(use_angle_cls=True, use_gpu=True, lang="ch")
        # diffusion model — placed on the last visible GPU to keep it off the
        # device(s) holding the language model.
        max_gpu_index = torch.cuda.device_count() - 1
        device_last = torch.device(f'cuda:{max_gpu_index}')
        print(torch.cuda.device_count(), '++++++', device_last)
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
        ).to(device=device_last)

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker (and its current status) to the controller."""
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Report queue length to the controller; re-register if it forgot us.

        Retries forever on network errors (5 s back-off).
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=30)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Number of requests currently held or waiting on the semaphore."""
        if model_semaphore is None:
            return 0
        else:
            # return 1
            # return args.limit_model_concurrency - model_semaphore._value + (0 if model_semaphore._waiters is None else len(model_semaphore._waiters))
            # Reads asyncio.Semaphore internals (_value/_waiters) — fragile
            # across Python versions, but standard in this codebase.
            return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Status dict consumed by the controller's scheduler."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    def add_content(self, prompt, new_content):
        """Insert `new_content` just before the final assistant marker.

        Supports Mistral ([INST]), ChatML (<|im_end|>) and the default
        '###Assistant:' template.
        """
        if '[INST]' in prompt:
            split_index = prompt.rfind(' [/INST]')
        elif '<|im_end|>' in prompt:
            split_index = prompt.rfind('<|im_end|>')
        else:
            split_index = prompt.rfind('###Assistant:')
        left_prompt = prompt[:split_index]
        right_prompt = prompt[split_index:]
        prompt = left_prompt + new_content + right_prompt
        return prompt

    @torch.inference_mode()
    def generate_stream(self, params):
        """Yield NUL-terminated JSON chunks of streamed generation.

        params keys used: prompt, images (base64 list), gen_image, use_ocr,
        temperature, top_p, max_new_tokens, stop. Each yielded chunk carries
        the full text so far (prompt included) plus error_code; the final
        chunk may additionally carry a base64 'image'.
        """
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        gen_image = params.get("gen_image", False)
        use_ocr = params.get("use_ocr", False)
        num_image_tokens = 0
        if gen_image:
            # Ask the model for an SDXL prompt by appending the <GEN> marker.
            prompt = self.add_content(prompt, ' <GEN>')
            print(prompt)
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <image> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]

                # add OCR tokens: run PaddleOCR on each image and splice the
                # recognized text (confidence > 0.1) into the prompt.
                if use_ocr:
                    str_in_image = ''
                    for image in images:
                        img_byte_arr = io.BytesIO()
                        image.save(img_byte_arr, format=image.format)
                        img_byte_arr = img_byte_arr.getvalue()
                        result = self.ocr_model.ocr(img_byte_arr, cls=True)
                        if result[0] is not None:
                            result = [res[1][0] for res in result[0] if res[1][1] > 0.1]
                            if len(result) > 0:
                                str_in_image += ', '.join(result)
                    # print('OCR Token: ' + str_in_image)
                    if len(str_in_image) > 0:
                        prompt = self.add_content(prompt, '\nReference OCR Token: ' + str_in_image + '\n')

                image_tensor = process_images(images, image_processor, model.config)
                image_grid = getattr(model.config, 'image_grid', 1)
                if hasattr(model.config, 'image_size_aux'):
                    # High-res copy feeds the aux branch; low-res (raw-size)
                    # interpolation feeds the main branch.
                    raw_shape = [image_processor.image_size_raw['height'] * image_grid,
                                 image_processor.image_size_raw['width'] * image_grid]
                    image_tensor_aux = image_tensor
                    image_tensor = torch.nn.functional.interpolate(image_tensor,
                                                                   size=raw_shape,
                                                                   mode='bilinear',
                                                                   align_corners=False)  # # torch.Size([1, 3, 336, 336])
                else:
                    image_tensor_aux = []

                if image_grid >= 2:
                    # Split the grid into image_grid**2 crops of raw size.
                    raw_image = image_tensor.reshape(3,
                                                     image_grid,
                                                     image_processor.image_size_raw['height'],
                                                     image_grid,
                                                     image_processor.image_size_raw['width'])
                    raw_image = raw_image.permute(1, 3, 0, 2, 4)
                    raw_image = raw_image.reshape(-1, 3,
                                                  image_processor.image_size_raw['height'],
                                                  image_processor.image_size_raw['width'])

                    if getattr(model.config, 'image_global', False):
                        global_image = image_tensor
                        if len(global_image.shape) == 3:
                            global_image = global_image[None]
                        global_image = torch.nn.functional.interpolate(global_image,
                                                                       size=[image_processor.image_size_raw['height'],
                                                                             image_processor.image_size_raw['width']],
                                                                       mode='bilinear',
                                                                       align_corners=False)
                        # [image_crops, image_global]
                        raw_image = torch.cat([raw_image, global_image], dim=0)
                    image_tensor = raw_image.contiguous()
                    image_tensor = image_tensor.unsqueeze(0)

                image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
                image_tensor_aux = image_tensor_aux.to(self.model.device, dtype=torch.float16)

                replace_token = DEFAULT_IMAGE_TOKEN
                if getattr(self.model.config, 'mm_use_im_start_end', False):
                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

                num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
            else:
                image_tensor = None
            image_args = {"images": image_tensor, "images_aux": image_tensor_aux}
        else:
            image_tensor = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30)

        # Leave room in the context for the prompt and image patch tokens.
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        # Generation runs in a background thread; tokens are consumed from the
        # streamer in this generator.
        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
        torch.cuda.empty_cache()

        # Image generation pass: the model wraps its SDXL prompt in <h>...</h>.
        if gen_image and "<h>" in generated_text and "</h>" in generated_text:
            # common_neg_prompt = "blur, lowres, bad anatomy, bad hands, cropped, worst quality"
            common_neg_prompt = "out of frame, lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
            prompt = generated_text.split("<h>")[1].split("</h>")[0]
            # yield json.dumps({"text": prompt, "error_code": 0}).encode() + b"\0"
            output_img = self.pipe(prompt, negative_prompt=common_neg_prompt).images[0]
            buffered = io.BytesIO()
            output_img.save(buffered, format='JPEG')
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            torch.cuda.empty_cache()
            generated_text = generated_text.split("<h>")[0] + '\n' + 'Prompt: ' + prompt + '\n'
            yield json.dumps({"text": generated_text, "image": img_b64_str, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Wrap generate_stream, converting exceptions into error chunks."""
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
# FastAPI application exposing the worker's HTTP endpoints.
app = FastAPI()


def release_model_semaphore(fn=None):
    """Release the global concurrency semaphore, then run the optional callback
    (used to push a heart beat after a request finishes)."""
    model_semaphore.release()
    if fn is not None:
        fn()
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """Stream generation chunks; concurrency-bounded by a lazily created
    global semaphore sized by --limit-model-concurrency."""
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    # Release the semaphore (and heart-beat) once the response is consumed.
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_get_status")
async def get_status(request: Request):
    """Return the worker's status dict (model names, speed, queue length)."""
    return worker.get_status()
if __name__ == "__main__":
    # CLI entry point: parse flags, build the ModelWorker (loads the model,
    # registers with the controller, starts heart beats), then serve HTTP.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--use-flash-attn", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")

    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device,
                         use_flash_attn=args.use_flash_attn)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
# ---- end of model_worker.py (concatenated file boundary) ----
"""
Manually register workers.
Usage:
python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
"""
import argparse
import requests
if __name__ == "__main__":
    # Manually register a worker with the controller (see module docstring).
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str)
    parser.add_argument("--worker-name", type=str)
    parser.add_argument("--check-heart-beat", action="store_true")
    args = parser.parse_args()

    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": args.worker_name,
        "check_heart_beat": args.check_heart_beat,
        # No status: the controller will query the worker directly.
        "worker_status": None,
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200
"""
A model worker executes the model.
"""
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import re
import uvicorn
from functools import partial
from mgm.constants import WORKER_HEART_BEAT_INTERVAL
from mgm.utils import (build_logger, server_error_msg,
pretty_print_semaphore)
from mgm.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
from mgm.constants import DEFAULT_IMAGE_TOKEN
import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint
GB = 1 << 30

# Short random id distinguishing this worker instance (used in the log name).
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Incremented per generate request; reported in heart-beat logs.
global_counter = 0
# Lazily-created asyncio.Semaphore bounding concurrent generations.
model_semaphore = None
def heart_beat_worker(controller):
    """Daemon loop: ping the controller every WORKER_HEART_BEAT_INTERVAL seconds."""
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
@sgl.function
def pipeline(s, prompt, max_tokens):
    """SGLang program: append interleaved text/image segments from `prompt`
    (a list of str and image items), then generate the "response" field."""
    for p in prompt:
        if type(p) is str:
            s += p
        else:
            s += sgl.image(p)
    s += sgl.gen("response", max_tokens=max_tokens)
class ModelWorker:
    """Proxies generation requests to a remote sglang runtime endpoint while
    keeping itself registered (and heart-beating) with the controller."""

    def __init__(self, controller_addr, worker_addr, sgl_endpoint,
                 worker_id, no_register, model_name):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        # Select backend
        backend = RuntimeEndpoint(sgl_endpoint)
        sgl.set_default_backend(backend)
        # Derive a human-readable model name from the runtime's model path
        # unless an explicit name was given.
        model_path = backend.model_info["model_path"]
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                # e.g. ".../my-model/checkpoint-100" -> "my-model_checkpoint-100"
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name
        logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
        if not no_register:
            self.register_to_controller()
            # NOTE(review): the heart-beat thread is not a daemon, so it keeps
            # the process alive after the main thread exits — confirm intended.
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker (and its status) to the controller."""
        logger.info("Register to controller")
        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Report liveness and queue length; retry forever on network errors.

        If the controller no longer knows this worker (e.g. it restarted),
        re-register.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")
        url = self.controller_addr + "/receive_heart_beat"
        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)
        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        # Requests currently running (limit - free slots) plus requests
        # waiting on the semaphore. Relies on asyncio.Semaphore internals
        # (_value/_waiters) and the module-level `args`.
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Status dict the controller uses for routing/load-balancing."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    async def generate_stream(self, params):
        """Yield b"\\0"-delimited JSON chunks with the accumulated text.

        params: dict with "prompt", optional base64-encoded "images", and
        sampling settings ("temperature", "top_p", "max_new_tokens", "stop").
        Raises ValueError when the image count mismatches <image> tokens.
        """
        ori_prompt = prompt = params["prompt"]
        images = params.get("images", None)
        if images is not None and len(images) > 0:
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
                images = [load_image_from_base64(image) for image in images]
                # FIXME: for image-start/end token
                # replace_token = DEFAULT_IMAGE_TOKEN
                # if getattr(self.model.config, 'mm_use_im_start_end', False):
                #     replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
                # Split the prompt at image tokens and interleave text with images
                # so `pipeline` can append them in order.
                prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
                prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
                prompt = []
                for i in range(len(prompt_split)):
                    prompt.append(prompt_split[i])
                    if i < len(images):
                        prompt.append(images[i])
        else:
            prompt = [prompt]
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        # NOTE(review): stop_str is computed but never passed to pipeline.run —
        # stop sequences appear to be ignored on this path; confirm intended.
        stop_str = params.get("stop", None)
        stop_str = [stop_str] if stop_str is not None else None
        print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
        state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
        # Each yielded chunk carries the full prompt + text generated so far.
        generated_text = ori_prompt
        async for text_outputs in state.text_async_iter(var_name="response"):
            generated_text += text_outputs
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    async def generate_stream_gate(self, params):
        """Wrap generate_stream, converting exceptions into error chunks so the
        HTTP stream always terminates cleanly."""
        try:
            async for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
app = FastAPI()


def release_model_semaphore(fn=None):
    # Free one concurrency slot; optionally run a follow-up callback
    # (used to send a heart beat after a streamed response finishes).
    model_semaphore.release()
    if fn is not None:
        fn()
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """HTTP entry point: acquire a concurrency slot, stream generation chunks,
    and release the slot (plus heart-beat) when the response completes."""
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()
    if model_semaphore is None:
        # Created lazily so it binds to the running event loop.
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    # Release the semaphore only after the client has consumed the stream.
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_get_status")
async def get_status(request: Request):
    # Controller polls this to learn model names and current queue length.
    return worker.get_status()
if __name__ == "__main__":
    # Worker entry point: parse CLI flags, connect to the sglang runtime,
    # register with the controller, then serve the FastAPI app.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
        default="http://localhost:21001")
    parser.add_argument("--model-name", type=str)
    # Address of the already-running sglang runtime this worker fronts.
    parser.add_argument("--sgl-endpoint", type=str)
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         args.sgl_endpoint,
                         worker_id,
                         args.no_register,
                         args.model_name)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
\ No newline at end of file
import argparse
import json
import requests
from mgm.conversation import default_conversation
def main():
    """Send one chat message to a model worker and stream the reply to stdout."""
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        # No worker given: ask the controller to refresh its worker list and
        # route us to the worker serving the requested model.
        controller_addr = args.controller_address
        requests.post(controller_addr + "/refresh_all_workers")
        resp = requests.post(controller_addr + "/list_models")
        models = sorted(resp.json()["models"])
        print(f"Models: {models}")

        resp = requests.post(controller_addr + "/get_worker_address",
                             json={"model": args.model_name})
        worker_addr = resp.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        # Controller knows no worker for this model.
        return

    # Build the prompt from the default conversation template.
    conv = default_conversation.copy()
    conv.append_message(conv.roles[0], args.message)
    prompt = conv.get_prompt()

    headers = {"User-Agent": "LLaVA Client"}
    payload = {
        "model": args.model_name,
        "prompt": prompt,
        "max_new_tokens": args.max_new_tokens,
        "temperature": 0.7,
        "stop": conv.sep,
    }
    response = requests.post(worker_addr + "/worker_generate_stream",
                             headers=headers, json=payload, stream=True)

    # Each b"\0"-delimited chunk carries the full text so far; overwrite the
    # current line with "\r" to show streaming progress.
    print(prompt.replace(conv.sep, "\n"), end="")
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        output = data["text"].split(conv.sep)[-1]
        print(output, end="\r")
    print("")
if __name__ == "__main__":
    # Simple CLI test client: talk to the controller (or a worker directly)
    # and stream one generation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--max-new-tokens", type=int, default=32)
    parser.add_argument("--message", type=str, default=
        "Tell me a story with more than 1000 words.")
    args = parser.parse_args()

    main()
from typing import Optional, Tuple
import warnings
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, rotate_half
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn import __version__ as flash_attn_version
from flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_kvpacked_func,
)
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Training-path replacement for LlamaAttention.forward using FlashAttention.

    Same contract as the stock implementation, except attention weights are
    never materialized (always returned as None).
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )
    bsz, q_len, _ = hidden_states.size()
    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )  # shape: (b, num_heads, s, head_dim)

    # Rotary embeddings need the total sequence length including cached KV.
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    if past_key_value is not None:
        # reuse k, v
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # Transform the data into the format required by flash attention
    qkv = torch.stack([query_states, key_states, value_states], dim=2)
    qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        # No padding: flatten batch*seq and feed uniform cu_seqlens.
        qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        max_s = q_len
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = output.view(bsz, q_len, -1)
    else:
        # Padded batch: strip padding tokens before the kernel, then restore.
        qkv = qkv.reshape(bsz, q_len, -1)
        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len)

    return self.o_proj(output), None, past_key_value
def apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids):
    """Apply rotary position embeddings to q and k laid out as (b, s, nh, hd).

    `cos_sin` is the (cos, sin) pair returned by the rotary-embedding module;
    `position_ids` selects the rows of those cached tables per token.
    """
    cos_cached, sin_cached = cos_sin
    batch = position_ids.shape[0]
    # Expand position ids to match the gather target's trailing dims
    # (dims 1 and 3 of the cached tables) so torch.gather can index per token.
    idx = position_ids[:, :, None, None].repeat(
        1, 1, cos_cached.shape[1], cos_cached.shape[3]
    )
    cos = torch.gather(cos_cached.transpose(1, 2).repeat(batch, 1, 1, 1), 1, idx)
    sin = torch.gather(sin_cached.transpose(1, 2).repeat(batch, 1, 1, 1), 1, idx)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
def forward_inference(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Inference-path FlashAttention replacement for LlamaAttention.forward.

    Works on (b, s, nh, hd) layouts internally; supports a KV cache and a
    boolean key-padding attention_mask. Attention weights are never returned.
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()
    kv_heads = getattr(self, "num_key_value_heads", self.num_heads)

    # Project q/k/v; k and v may use fewer heads (grouped-query attention).
    q, k, v = (
        op(hidden_states).view(bsz, q_len, nh, self.head_dim)
        for op, nh in (
            (self.q_proj, self.num_heads),
            (self.k_proj, kv_heads),
            (self.v_proj, kv_heads),
        )
    )
    # shape: (b, s, num_heads, head_dim)

    kv_seq_len = k.shape[1]
    past_kv_len = 0
    if past_key_value is not None:
        past_kv_len = past_key_value[0].shape[2]
        kv_seq_len += past_kv_len

    cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
    q, k = apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids)

    if past_key_value is not None:
        assert (
            flash_attn_version >= "2.1.0"
        ), "past_key_value support requires flash-attn >= 2.1.0"
        # reuse k, v  (cache is stored as (b, nh, s, hd); transpose to (b, s, nh, hd))
        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)

    past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None

    if attention_mask is None:
        output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
            bsz, q_len, -1
        )
    else:
        # Queries are only the new tokens, so slice their mask accordingly.
        q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
        # We can skip concat and call unpad twice but seems better to call unpad only once.
        kv, _, cu_k_lens, max_k = unpad_input(
            torch.stack((k, v), dim=2), attention_mask
        )
        output_unpad = flash_attn_varlen_kvpacked_func(
            q,
            kv,
            cu_q_lens,
            cu_k_lens,
            max_s,
            max_k,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len)
    return self.o_proj(output), None, past_key_value
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    # Pass-through: flash attention consumes the raw key-padding mask, so the
    # stock causal-mask expansion is deliberately skipped.
    return attention_mask
def _prepare_decoder_attention_mask_inference(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
if past_key_values_length > 0 and attention_mask is not None:
attention_mask = torch.cat(
(
torch.full(
(input_shape[0], past_key_values_length),
True,
dtype=attention_mask.dtype,
device=attention_mask.device,
),
attention_mask,
),
dim=-1,
)
if attention_mask is not None and torch.all(attention_mask):
return None # This uses the faster call when training with full samples
def replace_llama_attn_with_flash_attn(inference=False):
    """Monkey-patch transformers' LlamaAttention to use FlashAttention.

    inference=True installs the KV-cache-aware variants; otherwise the
    training-path variants are installed.
    """
    # Flash-attn backward for head dims > 64 requires sm80+ hardware.
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    if inference:
        transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inference
        transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_inference
    else:
        transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
            _prepare_decoder_attention_mask
        )
        transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
import logging
import math
from typing import Optional, Tuple
import torch
import transformers.models.llama.modeling_llama
from torch import nn
try:
import xformers.ops
except ImportError:
logging.error("xformers not found! Please install it before trying to use it.")
def replace_llama_attn_with_xformers_attn():
    # Monkey-patch LlamaAttention to use xformers memory-efficient attention.
    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """LlamaAttention.forward using xformers memory-efficient attention.

    Falls back to the vanilla matmul-softmax path when attention weights are
    requested (output_attentions=True).
    """
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    # Rotary embeddings need the total length including any cached KV.
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    (
        query_states,
        key_states,
    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # We only apply xformers optimizations if we don't need to output the whole attention matrix
    if not output_attentions:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=None
            )
        else:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states,
                key_states,
                value_states,
                attn_bias=xformers.ops.LowerTriangularMask(),
            )
        attn_weights = None
    else:
        # Vanilla attention so the full weight matrix can be returned.
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Clamp fully-masked rows so softmax stays numerically stable.
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights, past_key_value
import os
import torch
import torch.nn as nn
from torch.utils.data import Sampler
from transformers import Trainer
from transformers.trainer import (
is_sagemaker_mp_enabled,
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
logger,
)
from typing import List, Optional
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of `param`, gathering it first if DeepSpeed
    ZeRO-3 has partitioned it across ranks."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if not hasattr(param, "ds_id"):
        # Plain tensor: no ZeRO partitioning involved.
        return param.detach().cpu().clone()

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, 'no ignore status')
    # Gather the sharded parameter onto this rank before copying it out.
    with zero.GatheredParameters([param]):
        gathered = param.data.detach().cpu().clone()
    return gathered
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Select parameters whose names contain any of the given substrings and
    materialize them on CPU (gathering ZeRO-3 shards if needed)."""
    selected = {
        name: param
        for name, param in named_params
        if any(key in name for key in keys_to_match)
    }
    return {
        name: maybe_zero_3(param, ignore_status=True, name=name).cpu()
        for name, param in selected.items()
    }
def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `chunks` chunks of roughly equal lengths.
    """
    if len(indices) % num_chunks != 0:
        # Not evenly divisible: fall back to simple round-robin striping.
        return [indices[i::num_chunks] for i in range(num_chunks)]

    per_chunk = len(indices) // num_chunks
    chunks = [[] for _ in range(num_chunks)]
    loads = [0] * num_chunks
    for idx in indices:
        # Greedily place each index on the currently lightest chunk.
        lightest = loads.index(min(loads))
        chunks[lightest].append(idx)
        loads[lightest] += lengths[idx]
        if len(chunks[lightest]) == per_chunk:
            # Chunk reached capacity; ensure it is never selected again.
            loads[lightest] = float("inf")
    return chunks
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    """Group sample indices by modality and length for batching.

    Convention (from the caller): positive lengths mark multimodal samples,
    negative lengths mark language-only samples. Each megabatch contains only
    one modality; the leftovers of both modalities form one final megabatch.
    NOTE: the exact output depends on the order of torch RNG calls below —
    do not reorder statements.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    # Shuffle/group each modality independently (note: generator=None here,
    # so these draws use the global torch RNG).
    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    # The (possibly short) last megabatch of each modality is merged into one
    # trailing mixed batch appended after the shuffle.
    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    """Shuffle indices, slice them into megabatches, sort each megabatch by
    length (longest first), and split it into per-rank chunks of even load.

    Returns a flat list of dataset indices.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    shuffled = torch.randperm(len(lengths), generator=generator)
    mega = world_size * batch_size
    megabatches = []
    for start in range(0, len(lengths), mega):
        batch = shuffled[start : start + mega].tolist()
        batch.sort(key=lambda i: lengths[i], reverse=True)
        megabatches.append(split_to_even_chunks(batch, lengths, world_size))
    return [i for megabatch in megabatches for chunk in megabatch for i in chunk]
class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")
        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        # One index per dataset sample.
        return len(self.lengths)

    def __iter__(self):
        # Pick the grouping strategy; both helpers return a flat index list.
        grouper = (
            get_modality_length_grouped_indices
            if self.group_by_modality
            else get_length_grouped_indices
        )
        ordering = grouper(
            self.lengths, self.batch_size, self.world_size, generator=self.generator
        )
        return iter(ordering)
class LLaVATrainer(Trainer):
    """Trainer with LLaVA-specific behavior: modality/length-grouped sampling,
    per-module learning-rate overrides, and adapter-only checkpointing."""

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None
        if self.args.group_by_modality_length:
            # Group batches so multimodal and text-only samples of similar
            # length land together, reducing padding waste.
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.args.lr_multi is not None:
            # Parse "pattern1:mult1,pattern2:mult2" into {pattern: multiplier};
            # params whose names contain a pattern get learning_rate * multiplier.
            lr_multi_dict = {}
            for _dict in self.args.lr_multi.split(','):
                _key_val = _dict.split(':')
                print("_key_val:", _key_val)
                lr_multi_dict[_key_val[0]] = float(_key_val[1])

        if self.optimizer is None:
            # Weight decay applies to everything except biases and LayerNorm weights.
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            if self.args.mm_projector_lr is not None:
                # Four groups: (decay / no-decay) x (base / projector params),
                # with the projector groups using their own learning rate.
                projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.mm_projector_lr,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.mm_projector_lr,
                    },
                ]
            elif self.args.lr_multi is not None:
                # Base (decay / no-decay) groups exclude any param matched by
                # lr_multi; matched params get their own groups below.
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and not any([_key in n for _key in lr_multi_dict.keys()]))
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and not any([_key in n for _key in lr_multi_dict.keys()]))
                        ],
                        "weight_decay": 0.0,
                    },
                ]
                for _key in lr_multi_dict:
                    _key_decay = [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and _key in n)
                    ]
                    _key_no_decay = [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and _key in n)
                    ]
                    print("Params LR Change:", _key, "NUM:", len(_key_decay), len(_key_no_decay))
                    if len(_key_decay) > 0:
                        optimizer_grouped_parameters.append(
                            {
                                "params": _key_decay,
                                "lr": self.args.learning_rate * lr_multi_dict[_key],
                                "weight_decay": self.args.weight_decay,
                            },
                        )
                    if len(_key_no_decay) > 0:
                        optimizer_grouped_parameters.append(
                            {
                                "params": _key_no_decay,
                                "lr": self.args.learning_rate * lr_multi_dict[_key],
                                "weight_decay": 0.0,
                            },
                        )
            else:
                # Default: just the standard decay / no-decay split.
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                # Keep embedding layers in fp32 under 8-bit Adam for stability.
                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # Adapter-tuning mode: save only the multimodal adapter weights
            # (much smaller than a full model checkpoint).
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save Adapter
            keys_to_match = ['mm_projector', 'vision_resampler']
            keys_to_match.extend(['vlm_att', 'vlm_uni'])
            keys_to_match.extend(['vision_fpn', 'vision_stages', 'vision_tower'])
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(['embed_tokens', 'embed_in'])

            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            # Only rank 0 (or single-process runs) writes to disk.
            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        else:
            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # Adapter weights are already written by _save_checkpoint;
            # skip the full-model save.
            pass
        else:
            super(LLaVATrainer, self)._save(output_dir, state_dict)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment